diff python/ppci/c3/codegenerator.py @ 308:2e7f55319858

Merged analyse into codegenerator
author Windel Bouwman
date Fri, 13 Dec 2013 11:53:29 +0100
parents e609d5296ee9
children ff665880a6b0
line wrap: on
line diff
--- a/python/ppci/c3/codegenerator.py	Thu Dec 12 20:42:56 2013 +0100
+++ b/python/ppci/c3/codegenerator.py	Fri Dec 13 11:53:29 2013 +0100
@@ -1,18 +1,13 @@
 import logging
 from .. import ir
 from .. import irutils
-from .astnodes import Symbol, Package, Variable, Function
-from .astnodes import Statement, Empty, Compound, If, While, Assignment
-from .astnodes import ExpressionStatement, Return
-from .astnodes import Expression, Binop, Unop, Identifier, Deref, Member
-from .astnodes import Expression, FunctionCall, Literal, TypeCast
-from .astnodes import Type, DefinedType, BaseType, PointerType, StructureType
-
-from ppci import CompilerError
+from . import astnodes as ast
 
 
 class SemanticError(Exception):
+    """ Semantic error with a message (msg) and a source location (loc) """
     def __init__(self, msg, loc):
+        super().__init__(msg)  # keep msg in args so str(e) is not empty
         self.msg = msg
         self.loc = loc
 
@@ -33,8 +28,9 @@
         self.diag = diag
 
     def gencode(self, pkg):
+        """ Generate code for a single module """
         self.prepare()
-        assert type(pkg) is Package
+        assert type(pkg) is ast.Package
         self.pkg = pkg
         self.intType = pkg.scope['int']
         self.boolType = pkg.scope['bool']
@@ -42,16 +38,6 @@
         self.varMap = {}    # Maps variables to storage locations.
         self.funcMap = {}
         self.m = ir.Module(pkg.name)
-        self.genModule(pkg)
-        return self.m
-
-    def error(self, msg, loc=None):
-        self.pkg.ok = False
-        self.diag.error(msg, loc)
-
-    # inner helpers:
-    def genModule(self, pkg):
-        # Take care of forward declarations:
         try:
             for s in pkg.innerScope.Functions:
                 f = self.newFunction(s.name)
@@ -59,15 +45,20 @@
             for v in pkg.innerScope.Variables:
                 self.varMap[v] = self.newTemp()
             for s in pkg.innerScope.Functions:
-                self.genFunction(s)
+                self.gen_function(s)
         except SemanticError as e:
             self.error(e.msg, e.loc)
+        return self.m
+
+    def error(self, msg, loc=None):
+        self.pkg.ok = False  # clear the ok flag on the current package
+        self.diag.error(msg, loc)  # forward message and location to diag
 
     def checkType(self, t):
         """ Verify the type is correct """
         t = self.theType(t)
 
-    def genFunction(self, fn):
+    def gen_function(self, fn):
         # TODO: handle arguments
         f = self.funcMap[fn]
         f.return_value = self.newTemp()
@@ -81,43 +72,42 @@
             # TODO: handle parameters different
             if sym.isParameter:
                 self.checkType(sym.typ)
-                v = ir.Parameter(sym.name)
-                f.addParameter(v)
+                variable = ir.Parameter(sym.name)
+                f.addParameter(variable)
             elif sym.isLocal:
                 self.checkType(sym.typ)
-                v = ir.LocalVariable(sym.name)
-                f.addLocal(v)
-            elif isinstance(sym, Variable):
+                variable = ir.LocalVariable(sym.name)
+                f.addLocal(variable)
+            elif isinstance(sym, ast.Variable):
                 self.checkType(sym.typ)
-                v = ir.LocalVariable(sym.name)
-                f.addLocal(v)
+                variable = ir.LocalVariable(sym.name)
+                f.addLocal(variable)
             else:
-                #v = self.newTemp()
                 raise NotImplementedError('{}'.format(sym))
-            self.varMap[sym] = v
+            self.varMap[sym] = variable
 
         self.genCode(fn.body)
-        # Set the default return value to zero:
-        # TBD: this may not be required?
         self.emit(ir.Move(f.return_value, ir.Const(0)))
         self.emit(ir.Jump(f.epiloog))
         self.setFunction(None)
 
     def genCode(self, code):
+        """ Run gen_stmt on code; report a SemanticError via self.error """
         try:
-            self.genStmt(code)
+            self.gen_stmt(code)
         except SemanticError as e:
             self.error(e.msg, e.loc)
 
-    def genStmt(self, code):
-        assert isinstance(code, Statement)
+    def gen_stmt(self, code):
+        """ Generate code for a statement """
+        assert isinstance(code, ast.Statement)
         self.setLoc(code.loc)
-        if type(code) is Compound:
+        if type(code) is ast.Compound:
             for s in code.statements:
                 self.genCode(s)
-        elif type(code) is Empty:
+        elif type(code) is ast.Empty:
             pass
-        elif type(code) is Assignment:
+        elif type(code) is ast.Assignment:
             lval = self.genExprCode(code.lval)
             rval = self.genExprCode(code.rval)
             if not self.equalTypes(code.lval.typ, code.rval.typ):
@@ -126,13 +116,13 @@
             if not code.lval.lvalue:
                 raise SemanticError('No valid lvalue {}'.format(code.lval), code.lval.loc)
             self.emit(ir.Move(lval, rval))
-        elif type(code) is ExpressionStatement:
+        elif type(code) is ast.ExpressionStatement:
             self.emit(ir.Exp(self.genExprCode(code.ex)))
-        elif type(code) is If:
+        elif type(code) is ast.If:
             bbtrue = self.newBlock()
             bbfalse = self.newBlock()
             te = self.newBlock()
-            self.genCondCode(code.condition, bbtrue, bbfalse)
+            self.gen_cond_code(code.condition, bbtrue, bbfalse)
             self.setBlock(bbtrue)
             self.genCode(code.truestatement)
             self.emit(ir.Jump(te))
@@ -140,19 +130,19 @@
             self.genCode(code.falsestatement)
             self.emit(ir.Jump(te))
             self.setBlock(te)
-        elif type(code) is Return:
+        elif type(code) is ast.Return:
             re = self.genExprCode(code.expr)
             self.emit(ir.Move(self.fn.return_value, re))
             self.emit(ir.Jump(self.fn.epiloog))
             b = self.newBlock()
             self.setBlock(b)
-        elif type(code) is While:
+        elif type(code) is ast.While:
             bbdo = self.newBlock()
             bbtest = self.newBlock()
             te = self.newBlock()
             self.emit(ir.Jump(bbtest))
             self.setBlock(bbtest)
-            self.genCondCode(code.condition, bbdo, te)
+            self.gen_cond_code(code.condition, bbdo, te)
             self.setBlock(bbdo)
             self.genCode(code.statement)
             self.emit(ir.Jump(bbtest))
@@ -160,26 +150,27 @@
         else:
             raise NotImplementedError('Unknown stmt {}'.format(code))
 
-    def genCondCode(self, expr, bbtrue, bbfalse):
-        # Implement sequential logical operators
-        if type(expr) is Binop:
+    def gen_cond_code(self, expr, bbtrue, bbfalse):
+        """ Generate conditional logic.
+            Implement sequential logical operators. """
+        if type(expr) is ast.Binop:
             if expr.op == 'or':
                 l2 = self.newBlock()
-                self.genCondCode(expr.a, bbtrue, l2)
-                if not equalTypes(expr.a.typ, self.boolType):
+                self.gen_cond_code(expr.a, bbtrue, l2)
+                if not self.equalTypes(expr.a.typ, self.boolType):
                     raise SemanticError('Must be boolean', expr.a.loc)
                 self.setBlock(l2)
-                self.genCondCode(expr.b, bbtrue, bbfalse)
-                if not equalTypes(expr.b.typ, self.boolType):
+                self.gen_cond_code(expr.b, bbtrue, bbfalse)
+                if not self.equalTypes(expr.b.typ, self.boolType):
                     raise SemanticError('Must be boolean', expr.b.loc)
             elif expr.op == 'and':
                 l2 = self.newBlock()
-                self.genCondCode(expr.a, l2, bbfalse)
-                if not equalTypes(expr.a.typ, self.boolType):
+                self.gen_cond_code(expr.a, l2, bbfalse)
+                if not self.equalTypes(expr.a.typ, self.boolType):
                     self.error('Must be boolean', expr.a.loc)
                 self.setBlock(l2)
-                self.genCondCode(expr.b, bbtrue, bbfalse)
-                if not equalTypes(expr.b.typ, self.boolType):
+                self.gen_cond_code(expr.b, bbtrue, bbfalse)
+                if not self.equalTypes(expr.b.typ, self.boolType):
                     raise SemanticError('Must be boolean', expr.b.loc)
             elif expr.op in ['==', '>', '<', '!=', '<=', '>=']:
                 ta = self.genExprCode(expr.a)
@@ -191,20 +182,21 @@
             else:
                 raise NotImplementedError('Unknown condition {}'.format(expr))
             expr.typ = self.boolType
-        elif type(expr) is Literal:
+        elif type(expr) is ast.Literal:
+            self.genExprCode(expr)
             if expr.val:
                 self.emit(ir.Jump(bbtrue))
             else:
                 self.emit(ir.Jump(bbfalse))
-            expr.typ = self.boolType
         else:
             raise NotImplementedError('Unknown cond {}'.format(expr))
         if not self.equalTypes(expr.typ, self.boolType):
             self.error('Condition must be boolean', expr.loc)
 
     def genExprCode(self, expr):
-        assert isinstance(expr, Expression)
-        if type(expr) is Binop:
+        """ Generate code for an expression. Return the generated ir-value """
+        assert isinstance(expr, ast.Expression)
+        if type(expr) is ast.Binop:
             expr.lvalue = False
             if expr.op in ['+', '-', '*', '/', '<<', '>>', '|', '&']:
                 ra = self.genExprCode(expr.a)
@@ -213,15 +205,14 @@
                         self.equalTypes(expr.b.typ, self.intType):
                     expr.typ = expr.a.typ
                 else:
-                    # assume void here? TODO: throw exception!
                     raise SemanticError('Can only add integers', expr.loc)
             else:
                 raise NotImplementedError("Cannot use equality as expressions")
             return ir.Binop(ra, expr.op, rb)
-        elif type(expr) is Unop:
+        elif type(expr) is ast.Unop:
             if expr.op == '&':
                 ra = self.genExprCode(expr.a)
-                expr.typ = PointerType(expr.a.typ)
+                expr.typ = ast.PointerType(expr.a.typ)
                 if not expr.a.lvalue:
                     raise SemanticError('No valid lvalue', expr.a.loc)
                 expr.lvalue = False
@@ -229,50 +220,46 @@
                 return ra.e
             else:
                 raise NotImplementedError('Unknown unop {0}'.format(expr.op))
-        elif type(expr) is Identifier:
+        elif type(expr) is ast.Identifier:
             # Generate code for this identifier.
             expr.lvalue = True
             tg = self.resolveSymbol(expr)
             expr.kind = type(tg)
             expr.typ = tg.typ
             # This returns the dereferenced variable.
-            if type(tg) is Variable:
-                # TODO: now parameters are handled different. Not nice?
+            if isinstance(tg, ast.Variable):
                 return ir.Mem(self.varMap[tg])
             else:
-                return ir.Mem(self.varMap[tg])
-        elif type(expr) is Deref:
+                raise NotImplementedError(str(tg))
+        elif type(expr) is ast.Deref:
             # dereference pointer type:
             addr = self.genExprCode(expr.ptr)
             ptr_typ = self.theType(expr.ptr.typ)
             expr.lvalue = True
-            if type(ptr_typ) is PointerType:
+            if type(ptr_typ) is ast.PointerType:
                 expr.typ = ptr_typ.ptype
                 return ir.Mem(addr)
             else:
                 raise SemanticError('Cannot deref non-pointer', expr.loc)
-                expr.typ = self.intType
-                return ir.Mem(ir.Const(0))
-        elif type(expr) is Member:
+        elif type(expr) is ast.Member:
             base = self.genExprCode(expr.base)
             expr.lvalue = expr.base.lvalue
             basetype = self.theType(expr.base.typ)
-            if type(basetype) is StructureType:
+            if type(basetype) is ast.StructureType:
                 if basetype.hasField(expr.field):
                     expr.typ = basetype.fieldType(expr.field)
                 else:
                     raise SemanticError('{} does not contain field {}'
-                               .format(basetype, expr.field), expr.loc)
+                                        .format(basetype, expr.field), expr.loc)
             else:
-                raise SemanticError('Cannot select field {} of non-structure type {}'
-                           .format(sym.field, basetype), sym.loc)
+                raise SemanticError('Cannot select {} of non-structure type {}'
+                                    .format(expr.field, basetype), expr.loc)
 
             assert type(base) is ir.Mem, type(base)
-            base = base.e
             bt = self.theType(expr.base.typ)
             offset = ir.Const(bt.fieldOffset(expr.field))
-            return ir.Mem(ir.Add(base, offset))
-        elif type(expr) is Literal:
+            return ir.Mem(ir.Add(base.e, offset))
+        elif type(expr) is ast.Literal:
             expr.lvalue = False
             typemap = {int: 'int', float: 'double', bool: 'bool'}
             if type(expr.val) in typemap:
@@ -280,55 +267,59 @@
             else:
                 raise SemanticError('Unknown literal type {}'.format(expr.val))
             return ir.Const(expr.val)
-        elif type(expr) is TypeCast:
-            # TODO: improve this mess:
-            ar = self.genExprCode(expr.a)
-            ft = self.theType(expr.a.typ)
-            tt = self.theType(expr.to_type)
-            if isinstance(ft, PointerType) and isinstance(tt, PointerType):
-                expr.typ = expr.to_type
-                return ar
-            elif type(ft) is BaseType and ft.name == 'int' and \
-                    isinstance(tt, PointerType):
-                expr.typ = expr.to_type
-                return ar
-            else:
-                self.error('Cannot cast {} to {}'
-                           .format(ft, tt), expr.loc)
-                expr.typ = self.intType
-                return ir.Const(0)
-        elif type(expr) is FunctionCall:
-            # Evaluate the arguments:
-            args = [self.genExprCode(e) for e in expr.args]
-            # Check arguments:
-            if type(expr.proc) is Identifier:
-                tg = self.resolveSymbol(expr.proc)
-            else:
-                raise Exception()
-            assert type(tg) is Function
-            ftyp = tg.typ
-            fname = tg.name
-            ptypes = ftyp.parametertypes
-            if len(expr.args) != len(ptypes):
-                raise SemanticError('Function {2}: {0} arguments required, {1} given'
-                           .format(len(ptypes), len(expr.args), fname), expr.loc)
-            for arg, at in zip(expr.args, ptypes):
-                if not self.equalTypes(arg.typ, at):
-                    raise SemanticError('Got {0}, expected {1}'
-                               .format(arg.typ, at), arg.loc)
-            # determine return type:
-            expr.typ = ftyp.returntype
-            return ir.Call(fname, args)
+        elif type(expr) is ast.TypeCast:
+            return self.gen_type_cast(expr)
+        elif type(expr) is ast.FunctionCall:
+            return self.gen_function_call(expr)
         else:
             raise NotImplementedError('Unknown expr {}'.format(expr))
 
+    def gen_type_cast(self, expr):
+        """ Generate code for a cast from a pointer or int to a pointer """
+        value = self.genExprCode(expr.a)
+        src = self.theType(expr.a.typ)
+        dst = self.theType(expr.to_type)
+        # Accepted sources: any pointer type, or the base type named 'int'.
+        src_is_ptr = isinstance(src, ast.PointerType)
+        src_is_int = type(src) is ast.BaseType and src.name == 'int'
+        castable = isinstance(dst, ast.PointerType) and \
+            (src_is_ptr or src_is_int)
+        if not castable:
+            raise SemanticError('Cannot cast {} to {}'
+                                .format(src, dst), expr.loc)
+        expr.typ = expr.to_type
+        return value
+ 
+    def gen_function_call(self, expr):
+        """ Generate an ir.Call after checking callee, arity and arg types """
+        # Evaluate the arguments; genExprCode also fills in arg.typ:
+        args = [self.genExprCode(e) for e in expr.args]
+        # Resolve the callee; only function symbols can be called:
+        tg = self.resolveSymbol(expr.proc)
+        if type(tg) is not ast.Function:
+            raise SemanticError('cannot call {}'.format(tg), expr.loc)
+        ftyp = tg.typ
+        fname = tg.name
+        ptypes = ftyp.parametertypes
+        if len(expr.args) != len(ptypes):
+            raise SemanticError('{} requires {} arguments, {} given'
+                                .format(fname, len(ptypes), len(expr.args)),
+                                expr.loc)
+        for arg, at in zip(expr.args, ptypes):
+            if not self.equalTypes(arg.typ, at):
+                raise SemanticError('Got {}, expected {}'
+                                    .format(arg.typ, at), arg.loc)
+        expr.typ = ftyp.returntype  # the call has the function's return type
+        return ir.Call(fname, args)
+
     def resolveSymbol(self, sym):
-        assert type(sym) in [Identifier, Member]
-        if type(sym) is Member:
+        if type(sym) is ast.Member:
             base = self.resolveSymbol(sym.base)
+            if type(base) is not ast.Package:
+                raise SemanticError('Base is not a package', sym.loc)
             scope = base.innerScope
             name = sym.field
-        elif type(sym) is Identifier:
+        elif type(sym) is ast.Identifier:
             scope = sym.scope
             name = sym.target
         else:
@@ -337,23 +328,20 @@
             s = scope[name]
         else:
             raise SemanticError('{} undefined'.format(name), sym.loc)
-        assert isinstance(s, Symbol)
+        assert isinstance(s, ast.Symbol)
         return s
 
     def theType(self, t):
         """ Recurse until a 'real' type is found """
-        if type(t) is DefinedType:
+        if type(t) is ast.DefinedType:  # named type: unwrap to t.typ
             t = self.theType(t.typ)
-        elif type(t) is Identifier:
+        elif type(t) in [ast.Identifier, ast.Member]:  # Member: module types
             t = self.theType(self.resolveSymbol(t))
-        elif type(t) is Member:
-            # Possible when using types of modules:
-            t = self.theType(self.resolveSymbol(t))
-        elif isinstance(t, Type):
+        elif isinstance(t, ast.Type):  # already a real type: keep as-is
             pass
         else:
             raise NotImplementedError(str(t))
-        assert isinstance(t, Type)
+        assert isinstance(t, ast.Type)  # every accepted path ends in a Type
         return t
 
     def equalTypes(self, a, b):
@@ -361,15 +349,15 @@
         # Recurse into named types:
         a = self.theType(a)
         b = self.theType(b)
-        assert isinstance(a, Type)
-        assert isinstance(b, Type)
+        assert isinstance(a, ast.Type)
+        assert isinstance(b, ast.Type)
 
         if type(a) is type(b):
-            if type(a) is BaseType:
+            if type(a) is ast.BaseType:
                 return a.name == b.name
-            elif type(a) is PointerType:
+            elif type(a) is ast.PointerType:
                 return self.equalTypes(a.ptype, b.ptype)
-            elif type(a) is StructureType:
+            elif type(a) is ast.StructureType:
                 if len(a.mems) != len(b.mems):
                     return False
                 return all(self.equalTypes(am.typ, bm.typ) for am, bm in