diff python/ppci/c3/codegenerator.py @ 307:e609d5296ee9

Massive rewrite of codegenerator
author Windel Bouwman
date Thu, 12 Dec 2013 20:42:56 +0100
parents b145f8e6050b
children 2e7f55319858
line wrap: on
line diff
--- a/python/ppci/c3/codegenerator.py	Mon Dec 09 19:00:21 2013 +0100
+++ b/python/ppci/c3/codegenerator.py	Thu Dec 12 20:42:56 2013 +0100
@@ -1,11 +1,23 @@
 import logging
 from .. import ir
-from . import astnodes
+from .. import irutils
+from .astnodes import Symbol, Package, Variable, Function
+from .astnodes import Statement, Empty, Compound, If, While, Assignment
+from .astnodes import ExpressionStatement, Return
+from .astnodes import Expression, Binop, Unop, Identifier, Deref, Member
+from .astnodes import Expression, FunctionCall, Literal, TypeCast
+from .astnodes import Type, DefinedType, BaseType, PointerType, StructureType
+
 from ppci import CompilerError
-from .analyse import theType
 
 
-class CodeGenerator(ir.Builder):
+class SemanticError(Exception):
+    def __init__(self, msg, loc):
+        self.msg = msg
+        self.loc = loc
+
+
+class CodeGenerator(irutils.Builder):
     """
       Generates intermediate (IR) code from a package. The entry function is
       'genModule'. The main task of this part is to rewrite complex control
@@ -13,13 +25,19 @@
       jump statements. Also complex conditional statements are simplified.
       Such as 'and' and 'or' statements are rewritten in conditional jumps.
       And structured datatypes are rewritten.
+
+      Type checking is performed in the same pass as code generation.
     """
-    def __init__(self):
+    def __init__(self, diag):
         self.logger = logging.getLogger('c3cgen')
+        self.diag = diag
 
     def gencode(self, pkg):
         self.prepare()
-        assert type(pkg) is astnodes.Package
+        assert type(pkg) is Package
+        self.pkg = pkg
+        self.intType = pkg.scope['int']
+        self.boolType = pkg.scope['bool']
         self.logger.info('Generating ir-code for {}'.format(pkg.name))
         self.varMap = {}    # Maps variables to storage locations.
         self.funcMap = {}
@@ -27,16 +45,27 @@
         self.genModule(pkg)
         return self.m
 
+    def error(self, msg, loc=None):
+        self.pkg.ok = False
+        self.diag.error(msg, loc)
+
     # inner helpers:
     def genModule(self, pkg):
         # Take care of forward declarations:
-        for s in pkg.innerScope.Functions:
-            f = self.newFunction(s.name)
-            self.funcMap[s] = f
-        for v in pkg.innerScope.Variables:
-            self.varMap[v] = self.newTemp()
-        for s in pkg.innerScope.Functions:
-            self.genFunction(s)
+        try:
+            for s in pkg.innerScope.Functions:
+                f = self.newFunction(s.name)
+                self.funcMap[s] = f
+            for v in pkg.innerScope.Variables:
+                self.varMap[v] = self.newTemp()
+            for s in pkg.innerScope.Functions:
+                self.genFunction(s)
+        except SemanticError as e:
+            self.error(e.msg, e.loc)
+
+    def checkType(self, t):
+        """ Verify the type is correct """
+        t = self.theType(t)
 
     def genFunction(self, fn):
         # TODO: handle arguments
@@ -51,15 +80,20 @@
         for sym in fn.innerScope:
             # TODO: handle parameters different
             if sym.isParameter:
+                self.checkType(sym.typ)
                 v = ir.Parameter(sym.name)
                 f.addParameter(v)
             elif sym.isLocal:
+                self.checkType(sym.typ)
+                v = ir.LocalVariable(sym.name)
+                f.addLocal(v)
+            elif isinstance(sym, Variable):
+                self.checkType(sym.typ)
                 v = ir.LocalVariable(sym.name)
                 f.addLocal(v)
             else:
                 #v = self.newTemp()
                 raise NotImplementedError('{}'.format(sym))
-            # TODO: make this ssa here??
             self.varMap[sym] = v
 
         self.genCode(fn.body)
@@ -70,20 +104,31 @@
         self.setFunction(None)
 
     def genCode(self, code):
-        assert isinstance(code, astnodes.Statement)
+        try:
+            self.genStmt(code)
+        except SemanticError as e:
+            self.error(e.msg, e.loc)
+
+    def genStmt(self, code):
+        assert isinstance(code, Statement)
         self.setLoc(code.loc)
-        if type(code) is astnodes.CompoundStatement:
+        if type(code) is Compound:
             for s in code.statements:
                 self.genCode(s)
-        elif type(code) is astnodes.EmptyStatement:
+        elif type(code) is Empty:
             pass
-        elif type(code) is astnodes.Assignment:
-            rval = self.genExprCode(code.rval)
+        elif type(code) is Assignment:
             lval = self.genExprCode(code.lval)
+            rval = self.genExprCode(code.rval)
+            if not self.equalTypes(code.lval.typ, code.rval.typ):
+                msg = 'Cannot assign {} to {}'.format(code.lval.typ, code.rval.typ)
+                raise SemanticError(msg, code.loc)
+            if not code.lval.lvalue:
+                raise SemanticError('No valid lvalue {}'.format(code.lval), code.lval.loc)
             self.emit(ir.Move(lval, rval))
-        elif type(code) is astnodes.ExpressionStatement:
+        elif type(code) is ExpressionStatement:
             self.emit(ir.Exp(self.genExprCode(code.ex)))
-        elif type(code) is astnodes.IfStatement:
+        elif type(code) is If:
             bbtrue = self.newBlock()
             bbfalse = self.newBlock()
             te = self.newBlock()
@@ -95,13 +140,13 @@
             self.genCode(code.falsestatement)
             self.emit(ir.Jump(te))
             self.setBlock(te)
-        elif type(code) is astnodes.ReturnStatement:
+        elif type(code) is Return:
             re = self.genExprCode(code.expr)
             self.emit(ir.Move(self.fn.return_value, re))
             self.emit(ir.Jump(self.fn.epiloog))
             b = self.newBlock()
             self.setBlock(b)
-        elif type(code) is astnodes.WhileStatement:
+        elif type(code) is While:
             bbdo = self.newBlock()
             bbtest = self.newBlock()
             te = self.newBlock()
@@ -117,78 +162,218 @@
 
     def genCondCode(self, expr, bbtrue, bbfalse):
         # Implement sequential logical operators
-        if type(expr) is astnodes.Binop:
+        if type(expr) is Binop:
             if expr.op == 'or':
                 l2 = self.newBlock()
                 self.genCondCode(expr.a, bbtrue, l2)
+                if not equalTypes(expr.a.typ, self.boolType):
+                    raise SemanticError('Must be boolean', expr.a.loc)
                 self.setBlock(l2)
                 self.genCondCode(expr.b, bbtrue, bbfalse)
+                if not equalTypes(expr.b.typ, self.boolType):
+                    raise SemanticError('Must be boolean', expr.b.loc)
             elif expr.op == 'and':
                 l2 = self.newBlock()
                 self.genCondCode(expr.a, l2, bbfalse)
+                if not equalTypes(expr.a.typ, self.boolType):
+                    self.error('Must be boolean', expr.a.loc)
                 self.setBlock(l2)
                 self.genCondCode(expr.b, bbtrue, bbfalse)
+                if not equalTypes(expr.b.typ, self.boolType):
+                    raise SemanticError('Must be boolean', expr.b.loc)
             elif expr.op in ['==', '>', '<', '!=', '<=', '>=']:
                 ta = self.genExprCode(expr.a)
                 tb = self.genExprCode(expr.b)
+                if not self.equalTypes(expr.a.typ, expr.b.typ):
+                    raise SemanticError('Types unequal {} != {}'
+                               .format(expr.a.typ, expr.b.typ), expr.loc)
                 self.emit(ir.CJump(ta, expr.op, tb, bbtrue, bbfalse))
             else:
                 raise NotImplementedError('Unknown condition {}'.format(expr))
-        elif type(expr) is astnodes.Literal:
+            expr.typ = self.boolType
+        elif type(expr) is Literal:
             if expr.val:
                 self.emit(ir.Jump(bbtrue))
             else:
                 self.emit(ir.Jump(bbfalse))
+            expr.typ = self.boolType
         else:
             raise NotImplementedError('Unknown cond {}'.format(expr))
+        if not self.equalTypes(expr.typ, self.boolType):
+            self.error('Condition must be boolean', expr.loc)
 
     def genExprCode(self, expr):
-        assert isinstance(expr, astnodes.Expression)
-        if type(expr) is astnodes.Binop and expr.op in ir.Binop.ops:
-            ra = self.genExprCode(expr.a)
-            rb = self.genExprCode(expr.b)
+        assert isinstance(expr, Expression)
+        if type(expr) is Binop:
+            expr.lvalue = False
+            if expr.op in ['+', '-', '*', '/', '<<', '>>', '|', '&']:
+                ra = self.genExprCode(expr.a)
+                rb = self.genExprCode(expr.b)
+                if self.equalTypes(expr.a.typ, self.intType) and \
+                        self.equalTypes(expr.b.typ, self.intType):
+                    expr.typ = expr.a.typ
+                else:
+                    # Both operands must be int; anything else is a type error.
+                    raise SemanticError('Can only add integers', expr.loc)
+            else:
+                raise NotImplementedError("Cannot use equality as expressions")
             return ir.Binop(ra, expr.op, rb)
-        elif type(expr) is astnodes.Unop and expr.op == '&':
-            ra = self.genExprCode(expr.a)
-            assert type(ra) is ir.Mem
-            return ra.e
-        elif type(expr) is astnodes.VariableUse:
+        elif type(expr) is Unop:
+            if expr.op == '&':
+                ra = self.genExprCode(expr.a)
+                expr.typ = PointerType(expr.a.typ)
+                if not expr.a.lvalue:
+                    raise SemanticError('No valid lvalue', expr.a.loc)
+                expr.lvalue = False
+                assert type(ra) is ir.Mem
+                return ra.e
+            else:
+                raise NotImplementedError('Unknown unop {0}'.format(expr.op))
+        elif type(expr) is Identifier:
+            # Generate code for this identifier.
+            expr.lvalue = True
+            tg = self.resolveSymbol(expr)
+            expr.kind = type(tg)
+            expr.typ = tg.typ
             # This returns the dereferenced variable.
-            if expr.target.isParameter:
+            if type(tg) is Variable:
                 # TODO: now parameters are handled different. Not nice?
-                return self.varMap[expr.target]
+                return ir.Mem(self.varMap[tg])
             else:
-                return ir.Mem(self.varMap[expr.target])
-        elif type(expr) is astnodes.Deref:
+                return ir.Mem(self.varMap[tg])
+        elif type(expr) is Deref:
             # dereference pointer type:
             addr = self.genExprCode(expr.ptr)
-            return ir.Mem(addr)
-        elif type(expr) is astnodes.FieldRef:
+            ptr_typ = self.theType(expr.ptr.typ)
+            expr.lvalue = True
+            if type(ptr_typ) is PointerType:
+                expr.typ = ptr_typ.ptype
+                return ir.Mem(addr)
+            else:
+                raise SemanticError('Cannot deref non-pointer', expr.loc)
+                expr.typ = self.intType
+                return ir.Mem(ir.Const(0))
+        elif type(expr) is Member:
             base = self.genExprCode(expr.base)
+            expr.lvalue = expr.base.lvalue
+            basetype = self.theType(expr.base.typ)
+            if type(basetype) is StructureType:
+                if basetype.hasField(expr.field):
+                    expr.typ = basetype.fieldType(expr.field)
+                else:
+                    raise SemanticError('{} does not contain field {}'
+                               .format(basetype, expr.field), expr.loc)
+            else:
+                raise SemanticError('Cannot select field {} of non-structure type {}'
+                           .format(sym.field, basetype), sym.loc)
+
             assert type(base) is ir.Mem, type(base)
             base = base.e
-            bt = theType(expr.base.typ)
+            bt = self.theType(expr.base.typ)
             offset = ir.Const(bt.fieldOffset(expr.field))
             return ir.Mem(ir.Add(base, offset))
-        elif type(expr) is astnodes.Literal:
+        elif type(expr) is Literal:
+            expr.lvalue = False
+            typemap = {int: 'int', float: 'double', bool: 'bool'}
+            if type(expr.val) in typemap:
+                expr.typ = self.pkg.scope[typemap[type(expr.val)]]
+            else:
+                raise SemanticError('Unknown literal type {}'.format(expr.val))
             return ir.Const(expr.val)
-        elif type(expr) is astnodes.TypeCast:
+        elif type(expr) is TypeCast:
             # TODO: improve this mess:
             ar = self.genExprCode(expr.a)
-            tt = theType(expr.to_type)
-            if isinstance(tt, astnodes.PointerType):
-                if expr.a.typ is expr.scope['int']:
-                    return ar
-                elif isinstance(expr.a.typ, astnodes.PointerType):
-                    return ar
-                else:
-                    raise Exception()
+            ft = self.theType(expr.a.typ)
+            tt = self.theType(expr.to_type)
+            if isinstance(ft, PointerType) and isinstance(tt, PointerType):
+                expr.typ = expr.to_type
+                return ar
+            elif type(ft) is BaseType and ft.name == 'int' and \
+                    isinstance(tt, PointerType):
+                expr.typ = expr.to_type
+                return ar
             else:
-                raise NotImplementedError("not implemented")
-        elif type(expr) is astnodes.FunctionCall:
+                self.error('Cannot cast {} to {}'
+                           .format(ft, tt), expr.loc)
+                expr.typ = self.intType
+                return ir.Const(0)
+        elif type(expr) is FunctionCall:
+            # Evaluate the arguments:
             args = [self.genExprCode(e) for e in expr.args]
-            #fn = self.funcMap[expr.proc]
-            fn = expr.proc.name
-            return ir.Call(fn, args)
+            # Check arguments:
+            if type(expr.proc) is Identifier:
+                tg = self.resolveSymbol(expr.proc)
+            else:
+                raise Exception()
+            assert type(tg) is Function
+            ftyp = tg.typ
+            fname = tg.name
+            ptypes = ftyp.parametertypes
+            if len(expr.args) != len(ptypes):
+                raise SemanticError('Function {2}: {0} arguments required, {1} given'
+                           .format(len(ptypes), len(expr.args), fname), expr.loc)
+            for arg, at in zip(expr.args, ptypes):
+                if not self.equalTypes(arg.typ, at):
+                    raise SemanticError('Got {0}, expected {1}'
+                               .format(arg.typ, at), arg.loc)
+            # determine return type:
+            expr.typ = ftyp.returntype
+            return ir.Call(fname, args)
         else:
             raise NotImplementedError('Unknown expr {}'.format(expr))
+
+    def resolveSymbol(self, sym):
+        assert type(sym) in [Identifier, Member]
+        if type(sym) is Member:
+            base = self.resolveSymbol(sym.base)
+            scope = base.innerScope
+            name = sym.field
+        elif type(sym) is Identifier:
+            scope = sym.scope
+            name = sym.target
+        else:
+            raise NotImplementedError(str(sym))
+        if name in scope:
+            s = scope[name]
+        else:
+            raise SemanticError('{} undefined'.format(name), sym.loc)
+        assert isinstance(s, Symbol)
+        return s
+
+    def theType(self, t):
+        """ Recurse until a 'real' type is found """
+        if type(t) is DefinedType:
+            t = self.theType(t.typ)
+        elif type(t) is Identifier:
+            t = self.theType(self.resolveSymbol(t))
+        elif type(t) is Member:
+            # Possible when using types of modules:
+            t = self.theType(self.resolveSymbol(t))
+        elif isinstance(t, Type):
+            pass
+        else:
+            raise NotImplementedError(str(t))
+        assert isinstance(t, Type)
+        return t
+
+    def equalTypes(self, a, b):
+        """ Compare types a and b for structural equivalence. """
+        # Recurse into named types:
+        a = self.theType(a)
+        b = self.theType(b)
+        assert isinstance(a, Type)
+        assert isinstance(b, Type)
+
+        if type(a) is type(b):
+            if type(a) is BaseType:
+                return a.name == b.name
+            elif type(a) is PointerType:
+                return self.equalTypes(a.ptype, b.ptype)
+            elif type(a) is StructureType:
+                if len(a.mems) != len(b.mems):
+                    return False
+                return all(self.equalTypes(am.typ, bm.typ) for am, bm in
+                           zip(a.mems, b.mems))
+            else:
+                raise NotImplementedError('{} not implemented'.format(type(a)))
+        return False