view python/ppci/c3/codegenerator.py @ 347:742588fb8cd6 devel

Merge into devel branch
author Windel Bouwman
date Fri, 07 Mar 2014 17:10:21 +0100
parents d1ecc493384e
children b8ad45b3a573
line wrap: on
line source

import logging
from .. import ir
from .. import irutils
from . import astnodes as ast


class SemanticError(Exception):
    """ Error thrown when a semantic issue is observed.

    Attributes:
        msg: human readable description of the problem.
        loc: source location of the problem, or None when it is unknown.
    """
    def __init__(self, msg, loc=None):
        # Hand msg to Exception so that str(error) and tracebacks show it.
        super().__init__(msg)
        self.msg = msg
        self.loc = loc


class CodeGenerator(irutils.Builder):
    """
      Generates intermediate (IR) code from a package. The entry function is
      'gencode'. The main task of this part is to rewrite complex control
      structures, such as while and for loops, into simple conditional
      jump statements. Complex conditional statements are simplified as
      well: 'and' and 'or' are rewritten into conditional jumps.
      Structured datatypes are rewritten into memory accesses at field
      offsets.

      Type checking is done in one run with code generation. Semantic
      problems are reported to the diagnostics object given to __init__.
    """
    def __init__(self, diag):
        self.logger = logging.getLogger('c3cgen')
        self.diag = diag

    def gencode(self, pkg):
        """ Generate ir-code for a single package.

            Returns the generated ir.Module, or None when an error was
            reported for the package (pkg.ok is False afterwards).
        """
        self.prepare()
        assert type(pkg) is ast.Package
        self.pkg = pkg
        self.intType = pkg.scope['int']
        self.boolType = pkg.scope['bool']
        self.logger.debug('Generating ir-code for {}'.format(pkg.name),
                          extra={'c3_ast': pkg})
        self.varMap = {}    # Maps variables to storage locations.
        self.funcMap = {}   # Maps ast functions to their ir functions.
        self.m = ir.Module(pkg.name)
        try:
            # Create an ir function for every ast function up front;
            # gen_function looks them up in funcMap.
            for s in pkg.innerScope.Functions:
                f = self.newFunction(s.name)
                self.funcMap[s] = f
            for v in pkg.innerScope.Variables:
                self.varMap[v] = self.newTemp()
            for s in pkg.innerScope.Functions:
                self.gen_function(s)
        except SemanticError as e:
            self.error(e.msg, e.loc)
        if self.pkg.ok:
            return self.m

    def error(self, msg, loc=None):
        """ Report an error to the diagnostics and mark the package as
            failed, so that gencode does not return a module for it. """
        self.pkg.ok = False
        self.diag.error(msg, loc)

    def gen_function(self, fn):
        """ Generate ir-code for the ast function fn.

            Each parameter is copied into a local variable and varMap
            points at that copy, so parameter uses go through memory like
            any other local. Falling off the end of the body returns 0.
        """
        # TODO: handle arguments
        f = self.funcMap[fn]
        f.return_value = self.newTemp()
        self.setFunction(f)
        l2 = self.newBlock()
        self.emit(ir.Jump(l2))
        self.setBlock(l2)

        # Generate room for locals (and local copies of the parameters):
        for sym in fn.innerScope:
            # Resolving the type checks that it exists and, for structs,
            # computes the field offsets.
            self.the_type(sym.typ)
            if sym.isParameter:
                p = ir.Parameter(sym.name)
                variable = ir.LocalVariable(sym.name + '_copy')
                f.addParameter(p)
                f.addLocal(variable)
                # Move parameter into local copy:
                self.emit(ir.Move(ir.Mem(variable), p))
            elif sym.isLocal:
                variable = ir.LocalVariable(sym.name)
                f.addLocal(variable)
            elif isinstance(sym, ast.Variable):
                variable = ir.LocalVariable(sym.name)
                f.addLocal(variable)
            else:
                raise NotImplementedError('{}'.format(sym))
            self.varMap[sym] = variable

        self.genCode(fn.body)
        # Implicit 'return 0' when control reaches the end of the body:
        self.emit(ir.Move(f.return_value, ir.Const(0)))
        self.emit(ir.Jump(f.epiloog))
        self.setFunction(None)

    def genCode(self, code):
        """ Wrapper around gen_stmt to catch errors.

            A SemanticError aborts only the current statement; it is
            reported and code generation continues with the next one.
        """
        try:
            self.gen_stmt(code)
        except SemanticError as e:
            self.error(e.msg, e.loc)

    def gen_stmt(self, code):
        """ Generate code for a statement """
        assert isinstance(code, ast.Statement)
        self.setLoc(code.loc)
        if type(code) is ast.Compound:
            for s in code.statements:
                self.genCode(s)
        elif type(code) is ast.Empty:
            pass
        elif type(code) is ast.Assignment:
            lval = self.genExprCode(code.lval)
            rval = self.genExprCode(code.rval)
            if not self.equalTypes(code.lval.typ, code.rval.typ):
                msg = 'Cannot assign {} to {}'.format(code.rval.typ,
                                                      code.lval.typ)
                raise SemanticError(msg, code.loc)
            if not code.lval.lvalue:
                raise SemanticError('No valid lvalue {}'.format(code.lval),
                                    code.lval.loc)
            self.emit(ir.Move(lval, rval))
        elif type(code) is ast.ExpressionStatement:
            self.emit(ir.Exp(self.genExprCode(code.ex)))
        elif type(code) is ast.If:
            bbtrue = self.newBlock()
            bbfalse = self.newBlock()
            te = self.newBlock()
            self.gen_cond_code(code.condition, bbtrue, bbfalse)
            self.setBlock(bbtrue)
            self.genCode(code.truestatement)
            self.emit(ir.Jump(te))
            self.setBlock(bbfalse)
            self.genCode(code.falsestatement)
            self.emit(ir.Jump(te))
            self.setBlock(te)
        elif type(code) is ast.Return:
            re = self.genExprCode(code.expr)
            self.emit(ir.Move(self.fn.return_value, re))
            self.emit(ir.Jump(self.fn.epiloog))
            # Any code following the return goes into a fresh block:
            b = self.newBlock()
            self.setBlock(b)
        elif type(code) is ast.While:
            bbdo = self.newBlock()
            bbtest = self.newBlock()
            te = self.newBlock()
            self.emit(ir.Jump(bbtest))
            self.setBlock(bbtest)
            self.gen_cond_code(code.condition, bbdo, te)
            self.setBlock(bbdo)
            self.genCode(code.statement)
            self.emit(ir.Jump(bbtest))
            self.setBlock(te)
        elif type(code) is ast.For:
            # NOTE(review): only init, condition and body are generated;
            # confirm ast.For carries no separate increment statement.
            bbdo = self.newBlock()
            bbtest = self.newBlock()
            te = self.newBlock()
            self.genCode(code.init)
            self.emit(ir.Jump(bbtest))
            self.setBlock(bbtest)
            self.gen_cond_code(code.condition, bbdo, te)
            self.setBlock(bbdo)
            self.genCode(code.statement)
            self.emit(ir.Jump(bbtest))
            self.setBlock(te)
        else:
            raise NotImplementedError('Unknown stmt {}'.format(code))

    def gen_cond_code(self, expr, bbtrue, bbfalse):
        """ Generate conditional logic.

            Emits jumps to bbtrue when expr holds and to bbfalse otherwise.
            'and' and 'or' are short-circuited through an intermediate
            block.

            Raises SemanticError for non-boolean operands.
        """
        if type(expr) is ast.Binop:
            if expr.op == 'or':
                # Only evaluate b when a is false:
                l2 = self.newBlock()
                self.gen_cond_code(expr.a, bbtrue, l2)
                if not self.equalTypes(expr.a.typ, self.boolType):
                    raise SemanticError('Must be boolean', expr.a.loc)
                self.setBlock(l2)
                self.gen_cond_code(expr.b, bbtrue, bbfalse)
                if not self.equalTypes(expr.b.typ, self.boolType):
                    raise SemanticError('Must be boolean', expr.b.loc)
            elif expr.op == 'and':
                # Only evaluate b when a is true:
                l2 = self.newBlock()
                self.gen_cond_code(expr.a, l2, bbfalse)
                if not self.equalTypes(expr.a.typ, self.boolType):
                    raise SemanticError('Must be boolean', expr.a.loc)
                self.setBlock(l2)
                self.gen_cond_code(expr.b, bbtrue, bbfalse)
                if not self.equalTypes(expr.b.typ, self.boolType):
                    raise SemanticError('Must be boolean', expr.b.loc)
            elif expr.op in ['==', '>', '<', '!=', '<=', '>=']:
                ta = self.genExprCode(expr.a)
                tb = self.genExprCode(expr.b)
                if not self.equalTypes(expr.a.typ, expr.b.typ):
                    raise SemanticError('Types unequal {} != {}'
                                        .format(expr.a.typ, expr.b.typ),
                                        expr.loc)
                self.emit(ir.CJump(ta, expr.op, tb, bbtrue, bbfalse))
            else:
                raise SemanticError('non-bool: {}'.format(expr.op), expr.loc)
            expr.typ = self.boolType
        elif type(expr) is ast.Literal:
            # Constant condition: jump unconditionally to one target.
            self.genExprCode(expr)
            if expr.val:
                self.emit(ir.Jump(bbtrue))
            else:
                self.emit(ir.Jump(bbfalse))
        else:
            raise NotImplementedError('Unknown cond {}'.format(expr))
        if not self.equalTypes(expr.typ, self.boolType):
            self.error('Condition must be boolean', expr.loc)

    def genExprCode(self, expr):
        """ Generate code for an expression. Return the generated ir-value.

            As a side effect, expr.typ and expr.lvalue are filled in.
        """
        assert isinstance(expr, ast.Expression)
        if type(expr) is ast.Binop:
            expr.lvalue = False
            if expr.op in ['+', '-', '*', '/', '<<', '>>', '|', '&']:
                ra = self.genExprCode(expr.a)
                rb = self.genExprCode(expr.b)
                if self.equalTypes(expr.a.typ, self.intType) and \
                        self.equalTypes(expr.b.typ, self.intType):
                    expr.typ = expr.a.typ
                else:
                    raise SemanticError('Can only add integers', expr.loc)
            else:
                raise NotImplementedError("Cannot use equality as expressions")
            return ir.Binop(ra, expr.op, rb)
        elif type(expr) is ast.Unop:
            if expr.op == '&':
                # Address-of: the operand is a memory access; return its
                # address expression instead of the loaded value.
                ra = self.genExprCode(expr.a)
                expr.typ = ast.PointerType(expr.a.typ)
                if not expr.a.lvalue:
                    raise SemanticError('No valid lvalue', expr.a.loc)
                expr.lvalue = False
                assert type(ra) is ir.Mem
                return ra.e
            else:
                raise NotImplementedError('Unknown unop {0}'.format(expr.op))
        elif type(expr) is ast.Identifier:
            # Generate code for this identifier.
            tg = self.resolveSymbol(expr)
            expr.kind = type(tg)
            expr.typ = tg.typ
            # This returns the dereferenced variable.
            if isinstance(tg, ast.Variable):
                expr.lvalue = True
                return ir.Mem(self.varMap[tg])
            elif isinstance(tg, ast.Constant):
                c_val = self.genExprCode(tg.value)
                return self.evalConst(c_val, expr.loc)
            else:
                raise NotImplementedError(str(tg))
        elif type(expr) is ast.Deref:
            # dereference pointer type:
            addr = self.genExprCode(expr.ptr)
            ptr_typ = self.the_type(expr.ptr.typ)
            expr.lvalue = True
            if type(ptr_typ) is ast.PointerType:
                expr.typ = ptr_typ.ptype
                return ir.Mem(addr)
            else:
                raise SemanticError('Cannot deref non-pointer', expr.loc)
        elif type(expr) is ast.Member:
            base = self.genExprCode(expr.base)
            expr.lvalue = expr.base.lvalue
            basetype = self.the_type(expr.base.typ)
            if type(basetype) is ast.StructureType:
                if basetype.hasField(expr.field):
                    expr.typ = basetype.fieldType(expr.field)
                else:
                    raise SemanticError('{} does not contain field {}'
                                        .format(basetype, expr.field),
                                        expr.loc)
            else:
                raise SemanticError('Cannot select {} of non-structure type {}'
                                    .format(expr.field, basetype), expr.loc)

            # A field access is a memory access at base address + offset:
            assert type(base) is ir.Mem, type(base)
            offset = ir.Const(basetype.fieldOffset(expr.field))
            return ir.Mem(ir.Add(base.e, offset))
        elif type(expr) is ast.Literal:
            expr.lvalue = False
            typemap = {int: 'int', float: 'double', bool: 'bool',
                       str: 'string'}
            if type(expr.val) in typemap:
                expr.typ = self.pkg.scope[typemap[type(expr.val)]]
            else:
                raise SemanticError('Unknown literal type {}'
                                    .format(expr.val), expr.loc)
            return ir.Const(expr.val)
        elif type(expr) is ast.TypeCast:
            return self.gen_type_cast(expr)
        elif type(expr) is ast.FunctionCall:
            return self.gen_function_call(expr)
        else:
            raise NotImplementedError('Unknown expr {}'.format(expr))

    def gen_type_cast(self, expr):
        """ Generate code for type casting.

            Only pointer-to-pointer and int-to-pointer casts are allowed;
            both reuse the operand value unchanged.
        """
        ar = self.genExprCode(expr.a)
        from_type = self.the_type(expr.a.typ)
        to_type = self.the_type(expr.to_type)
        if isinstance(from_type, ast.PointerType) and \
                isinstance(to_type, ast.PointerType):
            expr.typ = expr.to_type
            return ar
        elif type(from_type) is ast.BaseType and from_type.name == 'int' and \
                isinstance(to_type, ast.PointerType):
            expr.typ = expr.to_type
            return ar
        else:
            raise SemanticError('Cannot cast {} to {}'
                                .format(from_type, to_type), expr.loc)

    def gen_function_call(self, expr):
        """ Generate code for a function call.

            The callee is referenced by the name '<package>_<function>'.
            Raises SemanticError when the callee is not a function or the
            arguments do not match the parameter types.
        """
        # Evaluate the arguments:
        args = [self.genExprCode(e) for e in expr.args]
        # Check arguments:
        tg = self.resolveSymbol(expr.proc)
        if type(tg) is not ast.Function:
            raise SemanticError('cannot call {}'.format(tg), expr.loc)
        ftyp = tg.typ
        fname = tg.package.name + '_' + tg.name
        ptypes = ftyp.parametertypes
        if len(expr.args) != len(ptypes):
            raise SemanticError('{} requires {} arguments, {} given'
                                .format(fname, len(ptypes), len(expr.args)),
                                expr.loc)
        for arg, at in zip(expr.args, ptypes):
            if not self.equalTypes(arg.typ, at):
                raise SemanticError('Got {}, expected {}'
                                    .format(arg.typ, at), arg.loc)
        # determine return type:
        expr.typ = ftyp.returntype
        return ir.Call(fname, args)

    def evalConst(self, c, loc=None):
        """ Return c when it is a compile-time constant (ir.Const).

            loc is the source location reported when c is not constant.
        """
        if isinstance(c, ir.Const):
            return c
        else:
            raise SemanticError('Cannot evaluate constant {}'.format(c), loc)

    def resolveSymbol(self, sym):
        """ Look up the ast.Symbol referred to by an Identifier or by a
            package-qualified Member ('pkg.name').

            Raises SemanticError when the name is undefined or the base of a
            Member is not a package.
        """
        if type(sym) is ast.Member:
            base = self.resolveSymbol(sym.base)
            if type(base) is not ast.Package:
                raise SemanticError('Base is not a package', sym.loc)
            scope = base.innerScope
            name = sym.field
        elif type(sym) is ast.Identifier:
            scope = sym.scope
            name = sym.target
        else:
            raise NotImplementedError(str(sym))
        if name in scope:
            s = scope[name]
        else:
            raise SemanticError('{} undefined'.format(name), sym.loc)
        assert isinstance(s, ast.Symbol)
        return s

    def size_of(self, t):
        """ Determine the byte size of a type.

            Structures are the plain sum of their member sizes (no padding).
        """
        t = self.the_type(t)
        if type(t) is ast.BaseType:
            return t.bytesize
        elif type(t) is ast.StructureType:
            return sum(self.size_of(mem.typ) for mem in t.mems)
        else:
            raise NotImplementedError(str(t))

    def the_type(self, t):
        """ Recurse until a 'real' type is found.

            Type names (Identifier/Member) and DefinedType aliases are
            resolved. For structures, the member offsets are (re)computed
            as a side effect.
        """
        if type(t) is ast.DefinedType:
            t = self.the_type(t.typ)
        elif type(t) in [ast.Identifier, ast.Member]:
            t = self.the_type(self.resolveSymbol(t))
        elif type(t) is ast.StructureType:
            # Setup offsets of fields. Is this the right place?:
            offset = 0
            for mem in t.mems:
                mem.offset = offset
                offset = offset + self.size_of(mem.typ)
        elif isinstance(t, ast.Type):
            pass
        else:
            raise NotImplementedError(str(t))
        assert isinstance(t, ast.Type)
        return t

    def equalTypes(self, a, b):
        """ Compare types a and b for structural equivalence. """
        # Recurse into named types:
        a = self.the_type(a)
        b = self.the_type(b)

        if type(a) is type(b):
            if type(a) is ast.BaseType:
                return a.name == b.name
            elif type(a) is ast.PointerType:
                return self.equalTypes(a.ptype, b.ptype)
            elif type(a) is ast.StructureType:
                if len(a.mems) != len(b.mems):
                    return False
                return all(self.equalTypes(am.typ, bm.typ) for am, bm in
                           zip(a.mems, b.mems))
            else:
                raise NotImplementedError('{} not implemented'
                                          .format(type(a)))
        return False