Aller au contenu

TP06 : Un générateur de code WASM pour Oberon-0

La dernière étape dans la construction de notre compilateur est d’écrire un générateur de code pour le langage WebAssembly (WASM).

Générateur de code

Voici une proposition d’implémentation du générateur de code (incomplet) en Python:

src/oberon0_compiler/code_gen.py
# SPDX-FileCopyrightText: 2026 Jacques Supcik <jacques.supcik@hefr.ch>
#
# SPDX-License-Identifier: MIT

"""
Oberon-0 WASM Code Generator
"""

from dataclasses import dataclass, field
from typing import BinaryIO

import wasm_gen as W  # noqa
from loguru import logger
from rich.console import Console
from wasm_gen import instructions as I  # noqa
from wasm_gen.type import i32_t

from . import ast
from . import sym_table as SYM
from .scanner import Position

console = Console()

# Host functions imported from the "sys" module, described as
# (name, parameter types, result types). The order is significant: a system
# call is retrieved from ``system_calls`` by its position in this list.
_SYSCALL_SIGNATURES = [
    ("OpenInput", [], []),
    ("ReadInt", [i32_t], []),
    ("eot", [], [i32_t]),
    ("WriteChar", [i32_t], []),
    ("WriteInt", [i32_t, i32_t], []),
    ("WriteLn", [], []),
]

# One (name, function) pair per system call; every entry gets its own
# ``BaseFunction`` object.
system_calls = [
    (name, W.BaseFunction(type=W.FunctionType(params=params, results=results)))
    for name, params, results in _SYSCALL_SIGNATURES
]


class CodeGenError(Exception):
    """Code-generation failure attached to a position in the compiled source."""

    def __init__(self, message: str, position: Position) -> None:
        # The message ends up in ``args[0]``, which ``__str__`` reads back.
        super().__init__(message)
        self.position = position

    def __str__(self) -> str:
        """Return the message followed by the file, line and column."""
        where = self.position
        return "{} (File {}, Line {}, Column {})".format(
            self.args[0], where.file_name, where.line_no, where.col_no
        )


@dataclass
class CodeGenerator:
    """Generate a WebAssembly module from an Oberon-0 AST.

    The generated module imports the system calls from the ``sys`` module, and
    a memory and a mutable ``__stack_pointer`` global from ``env``. Local
    variables are addressed as ``__stack_pointer + offset``; every procedure
    lowers the stack pointer by its ``stack_size`` on entry and restores it on
    exit.

    NOTE(review): the class is incomplete on purpose. ``expression`` and
    ``statement_sequence`` are called below but are not defined here.
    """

    # Module under construction; created by ``generate``.
    code: W.Module | None = None
    # Imported mutable i32 global used as stack pointer (see add_stack_pointer).
    _sp: W.BaseGlobal | None = None
    # Stack of the functions being emitted; the innermost one is the last item.
    _current_function: list[W.Function] = field(default_factory=list)

    def ensure(self, node: ast.Node, condition: bool, message: str) -> None:
        """Raise ``CodeGenError`` at ``node.position`` unless ``condition`` holds."""
        if not condition:
            raise CodeGenError(message, node.position)

    def current_function(self) -> W.Function:
        """Return the function whose body is currently being emitted."""
        assert self._current_function, "no function is being generated"
        return self._current_function[-1]

    def add_syscalls(self) -> None:
        """Import every entry of ``system_calls`` from the ``sys`` module."""
        logger.debug("Adding system calls")
        assert self.code is not None
        for name, func in system_calls:
            self.code.imports.append(
                W.Import(
                    node=func,
                    module="sys",
                    name=name,
                )
            )

    def add_memory(self) -> None:
        """Import ``env.memory`` with a minimum size of one page."""
        assert self.code is not None
        m1 = W.BaseMemory(type=W.MemoryType(min_pages=1))
        self.code.imports.append(W.Import(node=m1, module="env", name="memory"))

    def add_stack_pointer(self) -> None:
        """Import the mutable i32 global ``env.__stack_pointer`` and keep it."""
        assert self.code is not None
        self._sp = W.BaseGlobal(type=W.GlobalType(type=i32_t, mutable=True))
        self.code.imports.append(
            W.Import(node=self._sp, module="env", name="__stack_pointer")
        )

    def addr_of_symbol(self, node: ast.Node, sym: SYM.Symbol) -> None:
        """Emit the code that pushes the address of ``sym``.

        * local variable: stack pointer + ``sym.offset``;
        * global variable: the constant ``sym.offset``;
        * by-reference parameter: the content of its local, which already holds
          the address.

        Args:
            node: AST node used to report the error position.
            sym: symbol whose address is needed.

        Raises:
            CodeGenError: for a by-value parameter (it has no address) or for
                any other kind of symbol.
        """
        assert self._sp is not None
        fn = self.current_function()
        if isinstance(sym, SYM.LocalVariable):
            fn.body.extend(
                [
                    I.GlobalGet(global_=self._sp),
                    I.I32Const(value=sym.offset),
                    I.I32Add(),
                ]
            )
        elif isinstance(sym, SYM.GlobalVariable):
            fn.body.append(I.I32Const(value=sym.offset))
        elif isinstance(sym, SYM.FormalParameter):
            self.ensure(node, sym.by_ref, "Symbol must be by reference")
            fn.body.append(I.LocalGet(localidx=sym.index))
        else:
            raise CodeGenError(
                f"Unknown instance of symbol: {sym} (NOT YET IMPLEMENTED)",
                node.position,
            )

    def addr_of_expr(self, expr: ast.Expression) -> None:
        """Emit the code that pushes the address designated by ``expr``.

        Only a bare identifier is accepted: a simple expression without sign,
        without multiplicative factors and without additive terms.

        Raises:
            CodeGenError: if ``expr`` is anything other than a bare identifier.
        """
        self.ensure(expr, isinstance(expr, ast.SimpleExpression), "Expression expected")
        assert isinstance(expr, ast.SimpleExpression)
        self.ensure(expr, expr.sign is None, "Sign not allowed")
        self.ensure(
            expr,
            isinstance(expr.term.factor, ast.Ident),
            "Simple factor expected",
        )
        assert isinstance(expr.term.factor, ast.Ident)
        self.ensure(expr, len(expr.term.mulop_factors) == 0, "No mulop factors allowed")
        self.ensure(expr, len(expr.addop_terms) == 0, "No addop terms allowed")

        self.addr_of_symbol(expr, expr.term.factor.symbol)

    def function_call(self, f: ast.FunctionCall) -> None:
        """Emit a function call; only system calls are supported so far."""
        s = f.symbol
        self.ensure(f, isinstance(s, SYM.SystemCall), "Only system calls allowed")
        assert isinstance(s, SYM.SystemCall)
        self.system_call(f, s)

    def factor(self, f: ast.Factor) -> None:
        """Emit the code that pushes the value of the factor ``f``.

        Raises:
            CodeGenError: for an unknown kind of factor or of symbol.
        """
        assert self.code is not None and self._sp is not None
        fn = self.current_function()
        if isinstance(f, ast.Number):
            logger.debug(f"Number: {f.value}")
            fn.body.append(I.I32Const(value=f.value))
        elif isinstance(f, ast.Ident):
            sym = f.symbol
            if isinstance(sym, SYM.LocalVariable):
                # Value stored at stack pointer + offset.
                fn.body.extend(
                    [
                        I.GlobalGet(global_=self._sp),
                        I.I32Const(value=sym.offset),
                        I.I32Add(),
                        I.I32Load(),
                    ]
                )
            elif isinstance(sym, SYM.GlobalVariable):
                # Value stored at the constant address ``offset``.
                fn.body.extend(
                    [
                        I.I32Const(value=sym.offset),
                        I.I32Load(),
                    ]
                )
            elif isinstance(sym, SYM.FormalParameter):
                # The local holds the value itself (by value) or the address
                # of the value (by reference), which then has to be loaded.
                fn.body.append(I.LocalGet(localidx=sym.index))
                if sym.by_ref:
                    fn.body.append(I.I32Load())
            else:
                raise CodeGenError(f"Unknown symbol: {sym}", f.position)

        elif isinstance(f, ast.FunctionCall):
            self.function_call(f)
        elif isinstance(f, ast.Expression):
            # NOTE(review): presumably a parenthesised sub-expression; confirm
            # against the AST definition.
            self.expression(f.expression)
        else:
            raise CodeGenError(f"Unknown factor: {f}", f.position)

    def assignment(self, a: ast.Assignment) -> None:
        """Emit the code for the assignment ``a``.

        A by-value formal parameter is held in a WASM local, so only the value
        is computed and written with ``local.set``. Every other target lives in
        memory: its address is pushed first, then the value, then ``i32.store``.

        Raises:
            CodeGenError: if the assignment has no resolved symbol, or if
                ``addr_of_symbol`` does not support the target symbol.
        """
        fn = self.current_function()
        sym = a.symbol
        self.ensure(a, sym is not None, f"Unknown symbol: {sym}")
        assert sym is not None

        if isinstance(sym, SYM.FormalParameter) and not sym.by_ref:
            # Must be decided before calling addr_of_symbol(): a by-value
            # parameter has no address (addr_of_symbol rejects it), and
            # local.set consumes the value only, so nothing else may be left on
            # the operand stack.
            self.expression(a.expression)
            fn.body.append(I.LocalSet(localidx=sym.index))
            return

        self.addr_of_symbol(a, sym)
        self.expression(a.expression)
        fn.body.append(I.I32Store())

    def system_call(
        self, p: ast.FunctionCall | ast.ProcedureCall, s: SYM.SystemCall
    ) -> None:
        """Emit the arguments of a system call followed by the call itself.

        A by-reference argument is passed as an address, a by-value argument as
        the value of the expression.

        Raises:
            CodeGenError: if the number of arguments is wrong, or if a
                by-reference argument is not a bare identifier.
        """
        logger.debug(f"System call: {p.symbol.name}")
        self.ensure(p, len(p.params) == len(s.params), "Wrong number of arguments")
        # Both sequences have the same length here, so zip() drops nothing.
        for a, formal in zip(p.params, s.params):
            # TODO: check type
            if formal.by_ref:
                logger.debug(f"argument: {a} by ref")
                self.addr_of_expr(a)
            else:
                logger.debug(f"argument: {a} by val")
                self.expression(a)

        self.current_function().body.append(I.Call(function=system_calls[s.index][1]))

    def procedure(self, p: ast.ProcedureDeclaration) -> None:
        """Emit one WASM function for the procedure ``p``.

        The function is appended to the module, and also exported under the
        procedure name when the procedure is exported.

        NOTE(review): the function type is always ``() -> ()``, even for a
        non-exported procedure that declares parameters; confirm how formal
        parameters are meant to be mapped.

        Raises:
            CodeGenError: if an exported procedure has parameters, or if the
                body cannot be translated.
        """
        assert self.code is not None and self._sp is not None
        if p.exported:
            self.ensure(
                p,
                len(p.symbol.params) == 0,
                "Exported procedures cannot have parameters",
            )

        f = W.Function(type=W.FunctionType(params=[], results=[]))
        self._current_function.append(f)
        try:
            # Read once so that the postamble releases exactly what the
            # preamble reserved.
            frame_size = p.symbol.stack_size

            # Procedure preamble (make room for local variables)
            if frame_size > 0:
                f.body.extend(
                    [
                        I.GlobalGet(global_=self._sp),
                        I.I32Const(value=frame_size),
                        I.I32Sub(),
                        I.GlobalSet(global_=self._sp),
                    ]
                )

            self.statement_sequence(p.body.statements)

            # Procedure postamble (reclaim memory for local variables)
            if frame_size > 0:
                f.body.extend(
                    [
                        I.GlobalGet(global_=self._sp),
                        I.I32Const(value=frame_size),
                        I.I32Add(),
                        I.GlobalSet(global_=self._sp),
                    ]
                )

            f.body.append(I.End())
            self.code.funcs.append(f)
            if p.exported:
                self.code.exports.append(W.Export(node=f, name=p.symbol.name))
        finally:
            # Keep the function stack balanced even when code generation fails.
            self._current_function.pop()

    def generate(self, ast_: ast.Module, io: BinaryIO) -> None:
        """Generate the WASM module for ``ast_`` and write its bytes to ``io``.

        A fresh module is created, the system calls, the memory and the stack
        pointer are imported, and every procedure declaration is translated.

        NOTE(review): only ``declarations.procedure_declarations`` is visited.

        Raises:
            CodeGenError: if ``ast_`` is not a module or cannot be translated.
        """
        self.ensure(ast_, isinstance(ast_, ast.Module), "Module expected")
        self.code = W.Module()

        self.add_syscalls()
        self.add_memory()
        self.add_stack_pointer()

        d = ast_.declarations

        for p in d.procedure_declarations:
            self.procedure(p)

        io.write(bytes(self.code))

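Complétez le code ci-dessus (en particulier les méthodes expression et statement_sequence, qui sont appelées mais pas encore définies) et testez votre générateur de code en compilant les modules print42.mod, print2x.mod et add.mod que nous avons utilisés dans les TP précédents.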

print42.mod
(*
SPDX-FileCopyrightText: 2026 Jacques Supcik <jacques.supcik@hefr.ch>
SPDX-License-Identifier: MIT
*)

MODULE Test;

    (* Exported procedure (marked with "*"): prints 42 followed by a newline. *)
    PROCEDURE Print42*;
    BEGIN
        WriteInt(42, 5);  (* expected output is "   42", i.e. 42 in 5 columns *)
        WriteLn;
    END Print42;

END Test.
print2x.mod
(*
SPDX-FileCopyrightText: 2026 Jacques Supcik <jacques.supcik@hefr.ch>
SPDX-License-Identifier: MIT
*)

MODULE Test;

    (* Exported procedure: reads an integer and prints twice its value. *)
    PROCEDURE Print2X*;
    VAR x, y : INTEGER;  (* y is unused, yet the expected print2x.wat reserves 8 bytes *)
    BEGIN
        OpenInput;
        ReadInt(x);  (* x is passed by reference: the expected WAT pushes its address *)
        WriteInt(2 * x, 5);
        WriteLn;
    END Print2X;

END Test.
add.mod
(*
SPDX-FileCopyrightText: 2026 Jacques Supcik <jacques.supcik@hefr.ch>
SPDX-License-Identifier: MIT
*)

MODULE Test;

    (* Exported procedure: reads two integers and prints their sum. *)
    PROCEDURE Add*;
        VAR x, y, z: INTEGER;  (* three locals: 12 bytes of stack in the expected add.wat *)
    BEGIN
        OpenInput;
        ReadInt(x);
        ReadInt(y);
        z := x + y;  (* expected WAT: address of z, then x + y, then i32.store *)
        WriteInt(z, 5);
        WriteLn;
    END Add;

END Test.

Exécutez les fichiers WASM générés dans le runtime WASM que vous avez implémenté dans le TP02.

> oberon0-rt run print42.wasm Print42
   42
> oberon0-rt run print2x.wasm Print2X 21
   42
> oberon0-rt run add.wasm Add 31 11
   42

Voici le résultat attendu si on convertit les fichiers wasm en wat

print42.wat
(module
  (type $t0 (func))
  (type $t1 (func (param i32)))
  (type $t2 (func (result i32)))
  (type $t3 (func (param i32 i32)))
  (import "sys" "OpenInput" (func $sys.OpenInput (type $t0)))
  (import "sys" "ReadInt" (func $sys.ReadInt (type $t1)))
  (import "sys" "eot" (func $sys.eot (type $t2)))
  (import "sys" "WriteChar" (func $sys.WriteChar (type $t1)))
  (import "sys" "WriteInt" (func $sys.WriteInt (type $t3)))
  (import "sys" "WriteLn" (func $sys.WriteLn (type $t0)))
  (import "env" "memory" (memory $env.memory 1))
  (import "env" "__stack_pointer" (global $env.__stack_pointer (mut i32)))
  (func $Print42 (type $t0)
    i32.const 42
    i32.const 5
    call $sys.WriteInt
    call $sys.WriteLn)
  (export "Print42" (func $Print42)))
print2x.wat
(module
  (type $t0 (func))
  (type $t1 (func (param i32)))
  (type $t2 (func (result i32)))
  (type $t3 (func (param i32 i32)))
  (import "sys" "OpenInput" (func $sys.OpenInput (type $t0)))
  (import "sys" "ReadInt" (func $sys.ReadInt (type $t1)))
  (import "sys" "eot" (func $sys.eot (type $t2)))
  (import "sys" "WriteChar" (func $sys.WriteChar (type $t1)))
  (import "sys" "WriteInt" (func $sys.WriteInt (type $t3)))
  (import "sys" "WriteLn" (func $sys.WriteLn (type $t0)))
  (import "env" "memory" (memory $env.memory 1))
  (import "env" "__stack_pointer" (global $env.__stack_pointer (mut i32)))
  (func $Print2X (type $t0)
    global.get $env.__stack_pointer
    i32.const 8
    i32.sub
    global.set $env.__stack_pointer
    call $sys.OpenInput
    global.get $env.__stack_pointer
    i32.const 0
    i32.add
    call $sys.ReadInt
    i32.const 2
    global.get $env.__stack_pointer
    i32.const 0
    i32.add
    i32.load
    i32.mul
    i32.const 5
    call $sys.WriteInt
    call $sys.WriteLn
    global.get $env.__stack_pointer
    i32.const 8
    i32.add
    global.set $env.__stack_pointer)
  (export "Print2X" (func $Print2X)))
add.wat
(module
  (type $t0 (func))
  (type $t1 (func (param i32)))
  (type $t2 (func (result i32)))
  (type $t3 (func (param i32 i32)))
  (import "sys" "OpenInput" (func $sys.OpenInput (type $t0)))
  (import "sys" "ReadInt" (func $sys.ReadInt (type $t1)))
  (import "sys" "eot" (func $sys.eot (type $t2)))
  (import "sys" "WriteChar" (func $sys.WriteChar (type $t1)))
  (import "sys" "WriteInt" (func $sys.WriteInt (type $t3)))
  (import "sys" "WriteLn" (func $sys.WriteLn (type $t0)))
  (import "env" "memory" (memory $env.memory 1))
  (import "env" "__stack_pointer" (global $env.__stack_pointer (mut i32)))
  (func $Add (type $t0)
    global.get $env.__stack_pointer
    i32.const 12
    i32.sub
    global.set $env.__stack_pointer
    call $sys.OpenInput
    global.get $env.__stack_pointer
    i32.const 0
    i32.add
    call $sys.ReadInt
    global.get $env.__stack_pointer
    i32.const 4
    i32.add
    call $sys.ReadInt
    global.get $env.__stack_pointer
    i32.const 8
    i32.add
    global.get $env.__stack_pointer
    i32.const 0
    i32.add
    i32.load
    global.get $env.__stack_pointer
    i32.const 4
    i32.add
    i32.load
    i32.add
    i32.store
    global.get $env.__stack_pointer
    i32.const 8
    i32.add
    i32.load
    i32.const 5
    call $sys.WriteInt
    call $sys.WriteLn
    global.get $env.__stack_pointer
    i32.const 12
    i32.add
    global.set $env.__stack_pointer)
  (export "Add" (func $Add)))

Tests unitaires

Complétez vos tests unitaires pour le générateur de code.