view python/ppci/c3/codegenerator.py @ 307:e609d5296ee9

Massive rewrite of codegenerator
author Windel Bouwman
date Thu, 12 Dec 2013 20:42:56 +0100
parents b145f8e6050b
children 2e7f55319858
line wrap: on
line source

import logging
from .. import ir
from .. import irutils
from .astnodes import Symbol, Package, Variable, Function
from .astnodes import Statement, Empty, Compound, If, While, Assignment
from .astnodes import ExpressionStatement, Return
from .astnodes import Expression, Binop, Unop, Identifier, Deref, Member
from .astnodes import Expression, FunctionCall, Literal, TypeCast
from .astnodes import Type, DefinedType, BaseType, PointerType, StructureType

from ppci import CompilerError


class SemanticError(Exception):
    """Raised when the source program violates a semantic rule.

    Attributes:
        msg: Human readable description of the problem.
        loc: Source location of the offending construct, or None when no
            location is available.
    """

    def __init__(self, msg, loc=None):
        # Hand the message to Exception so that str(exc) and tracebacks
        # show it instead of an empty string.
        super().__init__(msg)
        self.msg = msg
        self.loc = loc


class CodeGenerator(irutils.Builder):
    """
      Generates intermediate (IR) code from a package. The entry function is
      'genModule'. The main task of this part is to rewrite complex control
      structures, such as while and for loops into simple conditional
      jump statements. Also complex conditional statements are simplified.
      Such as 'and' and 'or' statements are rewritten in conditional jumps.
      And structured datatypes are rewritten.

      Type checking is done in one run with code generation.
    """
    def __init__(self, diag):
        """Store the diagnostics sink `diag` and set up the 'c3cgen' logger."""
        self.diag = diag
        self.logger = logging.getLogger('c3cgen')

    def gencode(self, pkg):
        """Generate an ir module for the package `pkg` and return it.

        Sets up the per-package state (the package itself, the 'int' and
        'bool' types taken from its scope, and empty variable and function
        maps) and then lets genModule process the package.
        """
        self.prepare()
        assert type(pkg) is Package
        self.pkg = pkg
        scope = pkg.scope
        self.intType = scope['int']
        self.boolType = scope['bool']
        # Start with empty symbol tables for this package:
        self.funcMap = {}
        self.varMap = {}    # Maps variables to storage locations.
        self.logger.info('Generating ir-code for {}'.format(pkg.name))
        self.m = ir.Module(pkg.name)
        self.genModule(pkg)
        return self.m

    def error(self, msg, loc=None):
        """Flag the current package as not ok and report `msg` at `loc`."""
        package = self.pkg
        package.ok = False
        self.diag.error(msg, loc)

    # inner helpers:
    def genModule(self, pkg):
        """Generate code for all functions of the package `pkg`.

        A SemanticError that escapes from any of the steps is reported
        through self.error and ends the processing of the package.
        """
        try:
            scope = pkg.innerScope
            # Forward declarations: register every function and every
            # package level variable before any function body is generated.
            for func in scope.Functions:
                self.funcMap[func] = self.newFunction(func.name)
            for var in scope.Variables:
                self.varMap[var] = self.newTemp()
            # With all names known, generate the function bodies:
            for func in scope.Functions:
                self.genFunction(func)
        except SemanticError as err:
            self.error(err.msg, err.loc)

    def checkType(self, t):
        """ Verify the type is correct.

        The type is resolved with theType and the result is discarded; any
        exception raised during the resolution propagates to the caller.
        """
        self.theType(t)

    def genFunction(self, fn):
        """Generate ir-code for the function `fn`.

        The ir function registered for `fn` in funcMap gets a return value
        temporary, storage for every symbol of the function scope and the
        code of the function body.
        """
        # TODO: handle arguments
        irfunc = self.funcMap[fn]
        irfunc.return_value = self.newTemp()
        self.setFunction(irfunc)
        entry = self.newBlock()
        self.emit(ir.Jump(entry))
        self.setBlock(entry)

        # Generate room for parameters and locals:
        for sym in fn.innerScope:
            # TODO: handle parameters different
            if sym.isParameter:
                self.checkType(sym.typ)
                storage = ir.Parameter(sym.name)
                irfunc.addParameter(storage)
            elif sym.isLocal or isinstance(sym, Variable):
                # Locals and any other variable get a local slot.
                self.checkType(sym.typ)
                storage = ir.LocalVariable(sym.name)
                irfunc.addLocal(storage)
            else:
                raise NotImplementedError('{}'.format(sym))
            self.varMap[sym] = storage

        self.genCode(fn.body)
        # Falling off the end of the body returns zero by default.
        # TBD: this may not be required?
        self.emit(ir.Move(irfunc.return_value, ir.Const(0)))
        self.emit(ir.Jump(irfunc.epiloog))
        self.setFunction(None)

    def genCode(self, code):
        """Generate code for the statement `code`.

        A SemanticError raised while the statement is processed is handed
        to self.error instead of being propagated to the caller.
        """
        try:
            self.genStmt(code)
        except SemanticError as err:
            self.error(err.msg, err.loc)

    def genStmt(self, code):
        """Generate ir-code for a single statement.

        The statement is dispatched on its exact class. Structured
        statements (if, while) are lowered to basic blocks connected by
        jumps.

        Raises:
            SemanticError: when an assignment has operands of unequal type
                or a target that is no lvalue, or when raised by the
                expression code generation.
            NotImplementedError: when the statement class is unknown.
        """
        assert isinstance(code, Statement)
        self.setLoc(code.loc)
        if type(code) is Compound:
            for s in code.statements:
                self.genCode(s)
        elif type(code) is Empty:
            pass
        elif type(code) is Assignment:
            lval = self.genExprCode(code.lval)
            rval = self.genExprCode(code.rval)
            if not self.equalTypes(code.lval.typ, code.rval.typ):
                # The right hand side is the value that is assigned to the
                # left hand side, so name the types in that order.
                msg = 'Cannot assign {} to {}'.format(
                    code.rval.typ, code.lval.typ)
                raise SemanticError(msg, code.loc)
            if not code.lval.lvalue:
                raise SemanticError(
                    'No valid lvalue {}'.format(code.lval), code.lval.loc)
            self.emit(ir.Move(lval, rval))
        elif type(code) is ExpressionStatement:
            self.emit(ir.Exp(self.genExprCode(code.ex)))
        elif type(code) is If:
            bbtrue = self.newBlock()
            bbfalse = self.newBlock()
            te = self.newBlock()
            self.genCondCode(code.condition, bbtrue, bbfalse)
            self.setBlock(bbtrue)
            self.genCode(code.truestatement)
            self.emit(ir.Jump(te))
            self.setBlock(bbfalse)
            self.genCode(code.falsestatement)
            self.emit(ir.Jump(te))
            self.setBlock(te)
        elif type(code) is Return:
            re = self.genExprCode(code.expr)
            self.emit(ir.Move(self.fn.return_value, re))
            self.emit(ir.Jump(self.fn.epiloog))
            # Code following the return statement goes into a fresh block.
            b = self.newBlock()
            self.setBlock(b)
        elif type(code) is While:
            bbdo = self.newBlock()
            bbtest = self.newBlock()
            te = self.newBlock()
            self.emit(ir.Jump(bbtest))
            self.setBlock(bbtest)
            self.genCondCode(code.condition, bbdo, te)
            self.setBlock(bbdo)
            self.genCode(code.statement)
            # Back to the test after each iteration of the body.
            self.emit(ir.Jump(bbtest))
            self.setBlock(te)
        else:
            raise NotImplementedError('Unknown stmt {}'.format(code))

    def genCondCode(self, expr, bbtrue, bbfalse):
        """Generate control flow for the condition `expr`.

        Emits code that continues in block `bbtrue` when the condition
        holds and in block `bbfalse` otherwise. The operators 'or' and 'and'
        are evaluated sequentially through an intermediate block, so the
        right operand is only evaluated when needed. On success `expr.typ`
        is set to the boolean type.

        Raises:
            SemanticError: when an operand of 'or'/'and' is not boolean or
                the operands of a comparison have unequal types.
            NotImplementedError: for any other kind of expression.
        """
        # Implement sequential logical operators
        if type(expr) is Binop:
            if expr.op == 'or':
                l2 = self.newBlock()
                self.genCondCode(expr.a, bbtrue, l2)
                if not self.equalTypes(expr.a.typ, self.boolType):
                    raise SemanticError('Must be boolean', expr.a.loc)
                self.setBlock(l2)
                self.genCondCode(expr.b, bbtrue, bbfalse)
                if not self.equalTypes(expr.b.typ, self.boolType):
                    raise SemanticError('Must be boolean', expr.b.loc)
            elif expr.op == 'and':
                l2 = self.newBlock()
                self.genCondCode(expr.a, l2, bbfalse)
                if not self.equalTypes(expr.a.typ, self.boolType):
                    # Reported through the diagnostics, generation goes on.
                    self.error('Must be boolean', expr.a.loc)
                self.setBlock(l2)
                self.genCondCode(expr.b, bbtrue, bbfalse)
                if not self.equalTypes(expr.b.typ, self.boolType):
                    raise SemanticError('Must be boolean', expr.b.loc)
            elif expr.op in ['==', '>', '<', '!=', '<=', '>=']:
                ta = self.genExprCode(expr.a)
                tb = self.genExprCode(expr.b)
                if not self.equalTypes(expr.a.typ, expr.b.typ):
                    raise SemanticError('Types unequal {} != {}'
                               .format(expr.a.typ, expr.b.typ), expr.loc)
                self.emit(ir.CJump(ta, expr.op, tb, bbtrue, bbfalse))
            else:
                raise NotImplementedError('Unknown condition {}'.format(expr))
            expr.typ = self.boolType
        elif type(expr) is Literal:
            # A constant condition becomes an unconditional jump.
            if expr.val:
                self.emit(ir.Jump(bbtrue))
            else:
                self.emit(ir.Jump(bbfalse))
            expr.typ = self.boolType
        else:
            raise NotImplementedError('Unknown cond {}'.format(expr))
        if not self.equalTypes(expr.typ, self.boolType):
            self.error('Condition must be boolean', expr.loc)

    def genExprCode(self, expr):
        """Generate ir-code for the expression `expr` and type check it.

        The expression node is annotated while it is processed: `expr.typ`
        is set for every supported node kind, and `expr.lvalue` is set for
        binary and unary operators, identifiers, dereferences, member
        selections and literals.

        Returns:
            The ir expression that represents the value of `expr`.

        Raises:
            SemanticError: when the expression violates a typing rule.
            NotImplementedError: for unsupported operators, unsupported
                call targets and unknown expression kinds.
        """
        assert isinstance(expr, Expression)
        if type(expr) is Binop:
            expr.lvalue = False
            if expr.op in ['+', '-', '*', '/', '<<', '>>', '|', '&']:
                ra = self.genExprCode(expr.a)
                rb = self.genExprCode(expr.b)
                if self.equalTypes(expr.a.typ, self.intType) and \
                        self.equalTypes(expr.b.typ, self.intType):
                    expr.typ = expr.a.typ
                else:
                    raise SemanticError('Can only add integers', expr.loc)
            else:
                raise NotImplementedError("Cannot use equality as expressions")
            return ir.Binop(ra, expr.op, rb)
        elif type(expr) is Unop:
            if expr.op == '&':
                ra = self.genExprCode(expr.a)
                expr.typ = PointerType(expr.a.typ)
                if not expr.a.lvalue:
                    raise SemanticError('No valid lvalue', expr.a.loc)
                expr.lvalue = False
                # The address of a memory location is the address
                # expression inside the Mem node.
                assert type(ra) is ir.Mem
                return ra.e
            else:
                raise NotImplementedError('Unknown unop {0}'.format(expr.op))
        elif type(expr) is Identifier:
            # Generate code for this identifier.
            expr.lvalue = True
            tg = self.resolveSymbol(expr)
            expr.kind = type(tg)
            expr.typ = tg.typ
            # This returns the dereferenced variable.
            # TODO: parameters are not handled differently from variables.
            return ir.Mem(self.varMap[tg])
        elif type(expr) is Deref:
            # dereference pointer type:
            addr = self.genExprCode(expr.ptr)
            ptr_typ = self.theType(expr.ptr.typ)
            expr.lvalue = True
            if type(ptr_typ) is not PointerType:
                raise SemanticError('Cannot deref non-pointer', expr.loc)
            expr.typ = ptr_typ.ptype
            return ir.Mem(addr)
        elif type(expr) is Member:
            base = self.genExprCode(expr.base)
            expr.lvalue = expr.base.lvalue
            basetype = self.theType(expr.base.typ)
            if type(basetype) is StructureType:
                if basetype.hasField(expr.field):
                    expr.typ = basetype.fieldType(expr.field)
                else:
                    raise SemanticError('{} does not contain field {}'
                               .format(basetype, expr.field), expr.loc)
            else:
                raise SemanticError(
                    'Cannot select field {} of non-structure type {}'
                    .format(expr.field, basetype), expr.loc)

            # The field lives at a fixed offset from the address of the
            # structure.
            assert type(base) is ir.Mem, type(base)
            base = base.e
            offset = ir.Const(basetype.fieldOffset(expr.field))
            return ir.Mem(ir.Add(base, offset))
        elif type(expr) is Literal:
            expr.lvalue = False
            typemap = {int: 'int', float: 'double', bool: 'bool'}
            if type(expr.val) in typemap:
                expr.typ = self.pkg.scope[typemap[type(expr.val)]]
            else:
                raise SemanticError(
                    'Unknown literal type {}'.format(expr.val), expr.loc)
            return ir.Const(expr.val)
        elif type(expr) is TypeCast:
            # TODO: improve this mess:
            ar = self.genExprCode(expr.a)
            ft = self.theType(expr.a.typ)
            tt = self.theType(expr.to_type)
            if isinstance(ft, PointerType) and isinstance(tt, PointerType):
                expr.typ = expr.to_type
                return ar
            elif type(ft) is BaseType and ft.name == 'int' and \
                    isinstance(tt, PointerType):
                expr.typ = expr.to_type
                return ar
            else:
                # Report the error and continue with an integer zero.
                self.error('Cannot cast {} to {}'
                           .format(ft, tt), expr.loc)
                expr.typ = self.intType
                return ir.Const(0)
        elif type(expr) is FunctionCall:
            # Evaluate the arguments:
            args = [self.genExprCode(e) for e in expr.args]
            # Check arguments:
            if type(expr.proc) is Identifier:
                tg = self.resolveSymbol(expr.proc)
            else:
                raise NotImplementedError(
                    'Unsupported call target {}'.format(expr.proc))
            assert type(tg) is Function
            ftyp = tg.typ
            fname = tg.name
            ptypes = ftyp.parametertypes
            if len(expr.args) != len(ptypes):
                raise SemanticError(
                    'Function {2}: {0} arguments required, {1} given'
                    .format(len(ptypes), len(expr.args), fname), expr.loc)
            for arg, at in zip(expr.args, ptypes):
                if not self.equalTypes(arg.typ, at):
                    raise SemanticError('Got {0}, expected {1}'
                               .format(arg.typ, at), arg.loc)
            # determine return type:
            expr.typ = ftyp.returntype
            return ir.Call(fname, args)
        else:
            raise NotImplementedError('Unknown expr {}'.format(expr))

    def resolveSymbol(self, sym):
        """Find the symbol that the identifier or member `sym` refers to.

        An identifier is looked up by its target name in its own scope. A
        member is looked up by its field name in the inner scope of the
        symbol that its base resolves to.

        Raises:
            SemanticError: when the name is not present in the scope.
        """
        assert type(sym) in [Identifier, Member]
        if type(sym) is Identifier:
            scope, name = sym.scope, sym.target
        elif type(sym) is Member:
            scope = self.resolveSymbol(sym.base).innerScope
            name = sym.field
        else:
            raise NotImplementedError(str(sym))
        if name not in scope:
            raise SemanticError('{} undefined'.format(name), sym.loc)
        found = scope[name]
        assert isinstance(found, Symbol)
        return found

    def theType(self, t):
        """ Recurse until a 'real' type is found.

        Defined types are unwrapped, identifiers and members (the latter
        occur when types of other modules are used) are resolved to their
        symbol first. Anything that is neither of those nor a Type results
        in a NotImplementedError.
        """
        kind = type(t)
        if kind is DefinedType:
            t = self.theType(t.typ)
        elif kind in (Identifier, Member):
            t = self.theType(self.resolveSymbol(t))
        elif not isinstance(t, Type):
            raise NotImplementedError(str(t))
        assert isinstance(t, Type)
        return t

    def equalTypes(self, a, b):
        """ Compare types a and b for structural equivalence.

        Both types are resolved with theType first. Base types are equal
        when their names match, pointer types when their pointed-to types
        are equal, and structure types when they have the same number of
        members with pairwise equal types.
        """
        # Recurse into named types:
        a = self.theType(a)
        b = self.theType(b)
        assert isinstance(a, Type)
        assert isinstance(b, Type)

        kind = type(a)
        if kind is not type(b):
            return False
        if kind is BaseType:
            return a.name == b.name
        if kind is PointerType:
            return self.equalTypes(a.ptype, b.ptype)
        if kind is StructureType:
            if len(a.mems) != len(b.mems):
                return False
            for am, bm in zip(a.mems, b.mems):
                if not self.equalTypes(am.typ, bm.typ):
                    return False
            return True
        raise NotImplementedError('{} not implemented'.format(type(a)))