view python/ppci/c3/codegenerator.py @ 389:2ec730e45ea1

Added check for recursive struct
author Windel Bouwman
date Fri, 16 May 2014 12:29:31 +0200
parents c49459768aaa
children 6ae782a085e0
line wrap: on
line source

import logging
import struct
from .. import ir
from .. import irutils
from . import astnodes as ast


class SemanticError(Exception):
    """ Error raised when a semantic issue is observed.

        msg is the human readable description of the problem and loc is
        the source location it was found at. loc may be omitted when no
        location is known.
    """
    def __init__(self, msg, loc=None):
        # Hand the message to Exception so that str(err) and tracebacks
        # show it instead of an empty string.
        super().__init__(msg)
        self.msg = msg
        self.loc = loc


class CodeGenerator(irutils.Builder):
    """
      Generates intermediate (IR) code from a package. The entry function is
      'gencode'. The main task of this part is to rewrite complex control
      structures, such as while and for loops into simple conditional
      jump statements. Also complex conditional statements are simplified.
      Such as 'and' and 'or' statements are rewritten in conditional jumps.
      And structured datatypes are rewritten.

      Type checking is done in one run with code generation.
    """
    def __init__(self, diag):
        self.logger = logging.getLogger('c3cgen')
        self.diag = diag

    def gencode(self, pkg):
        """ Generate ir-code for a single package.

            Types are checked first, then storage is created for the
            package variables and finally every function that has a body
            is translated. A SemanticError raised on the way is reported
            through self.error.

            Returns the ir module when pkg.ok is true afterwards, None
            otherwise.
        """
        self.prepare()
        assert type(pkg) is ast.Package
        self.pkg = pkg
        self.intType = pkg.scope['int']
        self.boolType = pkg.scope['bool']
        self.pointerSize = 4
        self.logger.debug('Generating ir-code for {}'.format(pkg.name),
                          extra={'c3_ast': pkg})
        # Maps variables to storage locations:
        self.varMap = {}
        self.m = ir.Module(pkg.name)
        try:
            for typ in pkg.Types:
                self.check_type(typ)
            # Functions without a body are declarations only, skip them:
            bodied = [func for func in pkg.innerScope.Functions if func.body]
            for var in pkg.innerScope.Variables:
                storage = ir.GlobalVariable(var.name)
                self.varMap[var] = storage
                if not var.isLocal:
                    self.m.add_variable(storage)
            for func in bodied:
                self.gen_function(func)
        except SemanticError as err:
            self.error(err.msg, err.loc)
        if not self.pkg.ok:
            return None
        return self.m

    def error(self, msg, loc=None):
        self.pkg.ok = False
        self.diag.error(msg, loc)

    def gen_function(self, fn):
        """ Generate ir-code for the function fn.

            Creates the ir function with a fresh entry block, allocates a
            local for every symbol in the function scope (parameters are
            copied into a '<name>_copy' local), generates the body and
            ends with a default 'return 0' path to the epilogue.
        """
        # TODO: handle arguments
        ir_func = self.newFunction(fn.name)
        ir_func.return_value = self.newTemp()
        self.setFunction(ir_func)
        entry = self.newBlock()
        self.emit(ir.Jump(entry))
        self.setBlock(entry)

        # Reserve storage for every symbol in the function scope:
        for sym in fn.innerScope:
            self.check_type(sym.typ)
            if sym.isParameter:
                param = ir.Parameter(sym.name)
                storage = ir.LocalVariable(sym.name + '_copy')
                ir_func.addParameter(param)
                ir_func.addLocal(storage)
                # Move parameter into local copy:
                self.emit(ir.Move(ir.Mem(storage), param))
            elif sym.isLocal or isinstance(sym, ast.Variable):
                storage = ir.LocalVariable(sym.name)
                ir_func.addLocal(storage)
            else:
                raise NotImplementedError('{}'.format(sym))
            self.varMap[sym] = storage

        self.genCode(fn.body)
        # Path taken when the body runs off its end: return value 0.
        self.emit(ir.Move(ir_func.return_value, ir.Const(0)))
        self.emit(ir.Jump(ir_func.epiloog))
        self.setFunction(None)

    def genCode(self, code):
        """ Generate code for a statement through gen_stmt; a SemanticError
            is reported with self.error instead of being propagated. """
        try:
            self.gen_stmt(code)
        except SemanticError as problem:
            self.error(problem.msg, problem.loc)

    def gen_stmt(self, code):
        """ Generate code for one statement.

            Structured statements (if, while, for) are lowered to blocks
            connected by jumps. Raises SemanticError for an invalid
            assignment and NotImplementedError for an unknown statement
            class.
        """
        assert isinstance(code, ast.Statement)
        self.setLoc(code.loc)
        kind = type(code)

        if kind is ast.Compound:
            for inner in code.statements:
                self.genCode(inner)
            return

        if kind is ast.Empty:
            return

        if kind is ast.Assignment:
            target = self.genExprCode(code.lval)
            value = self.genExprCode(code.rval)
            if not self.equal_types(code.lval.typ, code.rval.typ):
                raise SemanticError('Cannot assign {} to {}'
                                    .format(code.rval.typ, code.lval.typ),
                                    code.loc)
            if not code.lval.lvalue:
                raise SemanticError('No valid lvalue {}'.format(code.lval),
                                    code.lval.loc)
            self.emit(ir.Move(target, value))
            return

        if kind is ast.ExpressionStatement:
            self.emit(ir.Exp(self.genExprCode(code.ex)))
            return

        if kind is ast.If:
            then_block = self.newBlock()
            else_block = self.newBlock()
            join_block = self.newBlock()
            self.gen_cond_code(code.condition, then_block, else_block)
            self.setBlock(then_block)
            self.genCode(code.truestatement)
            self.emit(ir.Jump(join_block))
            self.setBlock(else_block)
            self.genCode(code.falsestatement)
            self.emit(ir.Jump(join_block))
            self.setBlock(join_block)
            return

        if kind is ast.Return:
            result = self.genExprCode(code.expr)
            self.emit(ir.Move(self.fn.return_value, result))
            self.emit(ir.Jump(self.fn.epiloog))
            # Statements following the return end up in a fresh block:
            self.setBlock(self.newBlock())
            return

        if kind is ast.While:
            body_block = self.newBlock()
            test_block = self.newBlock()
            exit_block = self.newBlock()
            self.emit(ir.Jump(test_block))
            self.setBlock(test_block)
            self.gen_cond_code(code.condition, body_block, exit_block)
            self.setBlock(body_block)
            self.genCode(code.statement)
            self.emit(ir.Jump(test_block))
            self.setBlock(exit_block)
            return

        if kind is ast.For:
            body_block = self.newBlock()
            test_block = self.newBlock()
            exit_block = self.newBlock()
            self.genCode(code.init)
            self.emit(ir.Jump(test_block))
            self.setBlock(test_block)
            self.gen_cond_code(code.condition, body_block, exit_block)
            self.setBlock(body_block)
            self.genCode(code.statement)
            self.genCode(code.final)
            self.emit(ir.Jump(test_block))
            self.setBlock(exit_block)
            return

        raise NotImplementedError('Unknown stmt {}'.format(code))

    def gen_cond_code(self, expr, bbtrue, bbfalse):
        """ Generate conditional logic.
            Implement sequential logical operators.

            Code is emitted that jumps to bbtrue when expr holds and to
            bbfalse otherwise. 'and' and 'or' are short-circuited by
            routing the left operand through an intermediate block.

            Raises SemanticError when an operand of 'and' / 'or' is not
            boolean, when the operands of a comparison have different
            types or when the operator does not yield a boolean.
            Raises NotImplementedError for expression classes other than
            Binop and Literal.
        """
        if type(expr) is ast.Binop:
            if expr.op == 'or':
                # a or b: when a is false, continue by testing b in l2.
                l2 = self.newBlock()
                self.gen_cond_code(expr.a, bbtrue, l2)
                if not self.equal_types(expr.a.typ, self.boolType):
                    raise SemanticError('Must be boolean', expr.a.loc)
                self.setBlock(l2)
                self.gen_cond_code(expr.b, bbtrue, bbfalse)
                if not self.equal_types(expr.b.typ, self.boolType):
                    raise SemanticError('Must be boolean', expr.b.loc)
            elif expr.op == 'and':
                # a and b: when a is true, continue by testing b in l2.
                l2 = self.newBlock()
                self.gen_cond_code(expr.a, l2, bbfalse)
                if not self.equal_types(expr.a.typ, self.boolType):
                    # Raise like the other three operand checks do.
                    raise SemanticError('Must be boolean', expr.a.loc)
                self.setBlock(l2)
                self.gen_cond_code(expr.b, bbtrue, bbfalse)
                if not self.equal_types(expr.b.typ, self.boolType):
                    raise SemanticError('Must be boolean', expr.b.loc)
            elif expr.op in ['==', '>', '<', '!=', '<=', '>=']:
                ta = self.genExprCode(expr.a)
                tb = self.genExprCode(expr.b)
                if not self.equal_types(expr.a.typ, expr.b.typ):
                    raise SemanticError('Types unequal {} != {}'
                                        .format(expr.a.typ, expr.b.typ),
                                        expr.loc)
                self.emit(ir.CJump(ta, expr.op, tb, bbtrue, bbfalse))
            else:
                raise SemanticError('non-bool: {}'.format(expr.op), expr.loc)
            expr.typ = self.boolType
        elif type(expr) is ast.Literal:
            # genExprCode is called for its type annotation of the literal;
            # the branch target is selected right here from the value.
            self.genExprCode(expr)
            if expr.val:
                self.emit(ir.Jump(bbtrue))
            else:
                self.emit(ir.Jump(bbfalse))
        else:
            raise NotImplementedError('Unknown cond {}'.format(expr))

        # Check that the condition is a boolean value:
        if not self.equal_types(expr.typ, self.boolType):
            self.error('Condition must be boolean', expr.loc)

    def genExprCode(self, expr):
        """ Generate code for an expression and return the ir-value.

            As a side effect the expression node is annotated with its
            type (expr.typ) and, for most expression classes, with
            whether it can be assigned to (expr.lvalue).

            Raises SemanticError on type errors and NotImplementedError
            for constructs that are not supported.
        """
        assert isinstance(expr, ast.Expression)
        kind = type(expr)

        if kind is ast.Binop:
            expr.lvalue = False
            if expr.op not in ['+', '-', '*', '/', '<<', '>>', '|', '&']:
                raise NotImplementedError("Cannot use equality as expressions")
            lhs = self.genExprCode(expr.a)
            rhs = self.genExprCode(expr.b)
            both_int = self.equal_types(expr.a.typ, self.intType) and \
                self.equal_types(expr.b.typ, self.intType)
            if not both_int:
                raise SemanticError('Can only add integers', expr.loc)
            expr.typ = expr.a.typ
            return ir.Binop(lhs, expr.op, rhs)

        if kind is ast.Unop:
            if expr.op != '&':
                raise NotImplementedError('Unknown unop {0}'.format(expr.op))
            operand = self.genExprCode(expr.a)
            expr.typ = ast.PointerType(expr.a.typ)
            if not expr.a.lvalue:
                raise SemanticError('No valid lvalue', expr.a.loc)
            expr.lvalue = False
            # Address-of: strip the memory access and keep its address.
            assert type(operand) is ir.Mem
            return operand.e

        if kind is ast.Identifier:
            # Generate code for this identifier.
            target = self.resolveSymbol(expr)
            expr.kind = type(target)
            expr.typ = target.typ
            if isinstance(target, ast.Variable):
                # This returns the dereferenced variable.
                expr.lvalue = True
                return ir.Mem(self.varMap[target])
            if isinstance(target, ast.Constant):
                const_value = self.genExprCode(target.value)
                return self.evalConst(const_value)
            raise NotImplementedError(str(target))

        if kind is ast.Deref:
            # dereference pointer type:
            address = self.genExprCode(expr.ptr)
            pointer_type = self.the_type(expr.ptr.typ)
            expr.lvalue = True
            if type(pointer_type) is not ast.PointerType:
                raise SemanticError('Cannot deref non-pointer', expr.loc)
            expr.typ = pointer_type.ptype
            return ir.Mem(address)

        if kind is ast.Member:
            base = self.genExprCode(expr.base)
            expr.lvalue = expr.base.lvalue
            struct_type = self.the_type(expr.base.typ)
            if type(struct_type) is not ast.StructureType:
                raise SemanticError('Cannot select {} of non-structure type {}'
                                    .format(expr.field, struct_type), expr.loc)
            if not struct_type.hasField(expr.field):
                raise SemanticError('{} does not contain field {}'
                                    .format(struct_type, expr.field),
                                    expr.loc)
            expr.typ = struct_type.fieldType(expr.field)

            assert type(base) is ir.Mem, type(base)
            # The field lives at the base address plus the field offset:
            layout = self.the_type(expr.base.typ)
            field_offset = ir.Const(layout.fieldOffset(expr.field))
            return ir.Mem(ir.Add(base.e, field_offset))

        if kind is ast.Index:
            # Array indexing
            base = self.genExprCode(expr.base)
            index = self.genExprCode(expr.i)
            array_type = self.the_type(expr.base.typ)
            if not isinstance(array_type, ast.ArrayType):
                raise SemanticError('Cannot index non-array type {}'
                                    .format(array_type),
                                    expr.base.loc)
            index_type = self.the_type(expr.i.typ)
            if not self.equal_types(index_type, self.intType):
                raise SemanticError('Index must be int not {}'
                                    .format(index_type), expr.i.loc)
            assert type(base) is ir.Mem
            element_type = self.the_type(array_type.element_type)
            element_size = self.size_of(element_type)
            expr.typ = array_type.element_type
            expr.lvalue = True

            # Element address is base + index * element size:
            return ir.Mem(
                ir.Add(base.e, ir.Mul(index, ir.Const(element_size))))

        if kind is ast.Literal:
            expr.lvalue = False
            type_names = {int: 'int',
                          float: 'double',
                          bool: 'bool',
                          str: 'string'}
            value_kind = type(expr.val)
            if value_kind not in type_names:
                raise SemanticError('Unknown literal type {}'
                                    .format(expr.val), expr.loc)
            expr.typ = self.pkg.scope[type_names[value_kind]]
            # Construct correct const value:
            if value_kind is str:
                packed = self.pack_string(expr.val)
                return ir.Addr(ir.Const(packed))
            return ir.Const(expr.val)

        if kind is ast.TypeCast:
            return self.gen_type_cast(expr)

        if kind is ast.FunctionCall:
            return self.gen_function_call(expr)

        raise NotImplementedError('Unknown expr {}'.format(expr))

    def pack_string(self, txt):
        """ Return txt as bytes: a 4 byte little-endian length prefix
            followed by the ascii encoded text. """
        header = struct.pack('<I', len(txt))
        return header + txt.encode('ascii')

    def gen_type_cast(self, expr):
        """ Generate code for a type cast.

            Accepted casts are pointer to pointer, int to pointer,
            pointer to int and byte to int. For those the value of the
            operand is returned unchanged and only expr.typ is set.
            Any other cast raises SemanticError.
        """
        value = self.genExprCode(expr.a)
        from_type = self.the_type(expr.a.typ)
        to_type = self.the_type(expr.to_type)
        from_ptr = isinstance(from_type, ast.PointerType)
        to_ptr = isinstance(to_type, ast.PointerType)

        # The alternatives are tried in this order and evaluation stops at
        # the first one that holds:
        allowed = (
            (from_ptr and to_ptr)
            or (self.equal_types(self.intType, from_type) and to_ptr)
            or (self.equal_types(self.intType, to_type) and from_ptr)
            or (type(from_type) is ast.BaseType and from_type.name == 'byte'
                and type(to_type) is ast.BaseType and to_type.name == 'int')
        )
        if not allowed:
            raise SemanticError('Cannot cast {} to {}'
                                .format(from_type, to_type), expr.loc)
        expr.typ = expr.to_type
        return value

    def gen_function_call(self, expr):
        """ Generate code for a function call.

            The arguments are evaluated first, then the callee is resolved
            and the argument count and types are checked against the
            parameter types of the function. expr.typ is set to the
            return type of the function.

            Returns an ir.Call to '<package name>_<function name>'.
            Raises SemanticError when the callee is not a function, when
            the number of arguments is wrong or when an argument type
            does not match.
        """
        # Evaluate the arguments:
        args = [self.genExprCode(e) for e in expr.args]
        # Check arguments:
        tg = self.resolveSymbol(expr.proc)
        if type(tg) is not ast.Function:
            # SemanticError needs a location; without it the construction
            # itself fails with a TypeError instead of a diagnostic.
            raise SemanticError('cannot call {}'.format(tg), expr.loc)
        ftyp = tg.typ
        fname = tg.package.name + '_' + tg.name
        ptypes = ftyp.parametertypes
        if len(expr.args) != len(ptypes):
            raise SemanticError('{} requires {} arguments, {} given'
                                .format(fname, len(ptypes), len(expr.args)),
                                expr.loc)
        for arg, at in zip(expr.args, ptypes):
            if not self.equal_types(arg.typ, at):
                raise SemanticError('Got {}, expected {}'
                                    .format(arg.typ, at), arg.loc)
        # determine return type:
        expr.typ = ftyp.returntype
        return ir.Call(fname, args)

    def evalConst(self, c):
        """ Return c when it is an ir.Const.

            Raises SemanticError for any other ir-value.
        """
        if isinstance(c, ir.Const):
            return c
        # No source location is available for an ir-value, so pass None
        # explicitly; SemanticError requires the argument.
        raise SemanticError('Cannot evaluate constant {}'.format(c), None)

    def resolveSymbol(self, sym):
        """ Look up the symbol that sym refers to.

            An Identifier is searched in its own scope. A Member is
            searched in the inner scope of its base, which must resolve to
            a package. Raises SemanticError when the base is not a package
            or when the name is undefined.
        """
        kind = type(sym)
        if kind is ast.Member:
            package = self.resolveSymbol(sym.base)
            if type(package) is not ast.Package:
                raise SemanticError('Base is not a package', sym.loc)
            scope, name = package.innerScope, sym.field
        elif kind is ast.Identifier:
            scope, name = sym.scope, sym.target
        else:
            raise NotImplementedError(str(sym))

        if name not in scope:
            raise SemanticError('{} undefined'.format(name), sym.loc)
        found = scope[name]
        assert isinstance(found, ast.Symbol)
        return found

    def size_of(self, t):
        """ Return the size in bytes of type t.

            A structure is the sum of its members, an array is its size
            times the element size and a pointer is self.pointerSize.
        """
        t = self.the_type(t)
        kind = type(t)
        if kind is ast.BaseType:
            return t.bytesize
        if kind is ast.StructureType:
            total = 0
            for member in t.mems:
                total = total + self.size_of(member.typ)
            return total
        if kind is ast.ArrayType:
            return t.size * self.size_of(t.element_type)
        if kind is ast.PointerType:
            return self.pointerSize
        raise NotImplementedError(str(t))

    def the_type(self, t, reveil_defined=True):
        """ Follow references until an actual type object is reached.

            Identifier and Member nodes are resolved to their symbol.
            Defined types are replaced by their backing type only when
            reveil_defined is True.
        """
        kind = type(t)
        if kind is ast.DefinedType:
            if reveil_defined:
                t = self.the_type(t.typ)
        elif kind in [ast.Identifier, ast.Member]:
            symbol = self.resolveSymbol(t)
            t = self.the_type(symbol, reveil_defined)
        elif not isinstance(t, ast.Type):
            raise NotImplementedError(str(t))
        assert isinstance(t, ast.Type)
        return t

    def equal_types(self, a, b, byname=False):
        """ Compare types a and b for structural equivalence.
            When byname is True, defined types are not expanded and are
            compared by their name instead.
        """
        # Recurse into named types:
        reveal = not byname
        a = self.the_type(a, reveal)
        b = self.the_type(b, reveal)

        # Check types for sanity:
        self.check_type(a)
        self.check_type(b)

        # Types of a different class are never equal:
        kind = type(a)
        if kind is not type(b):
            return False

        if kind is ast.BaseType:
            return a.name == b.name
        if kind is ast.PointerType:
            # Behind a pointer the comparison switches to by-name, so the
            # structural comparison stops at defined types:
            return self.equal_types(a.ptype, b.ptype, byname=True)
        if kind is ast.StructureType:
            if len(a.mems) != len(b.mems):
                return False
            for mem_a, mem_b in zip(a.mems, b.mems):
                if not self.equal_types(mem_a.typ, mem_b.typ):
                    return False
            return True
        if kind is ast.ArrayType:
            return self.equal_types(a.element_type, b.element_type)
        if kind is ast.DefinedType:
            # Try by name in case of defined types:
            return a.name == b.name
        raise NotImplementedError('{} not implemented'.format(kind))

    def check_type(self, t, first=True, byname=False):
        """ Determine struct offsets and check for recursiveness.

            self.got_types holds the structures that are currently being
            expanded. Reaching one of them again means that the structure
            contains itself by value, which raises SemanticError.

            first must be True (the default) for the outermost call, which
            resets the marks; the recursive calls pass first=False.
            When byname is True, defined types are not expanded. This is
            used behind pointers, so a pointer to a named type ends the
            walk.
        """

        # Reset the marks:
        if first:
            self.got_types = set()

        # Resolve the type:
        t = self.the_type(t, not byname)

        # Check for recursion:
        if t in self.got_types:
            raise SemanticError('Recursive data type {}'.format(t), None)

        if type(t) is ast.BaseType:
            pass
        elif type(t) is ast.PointerType:
            # If a pointed type is detected, stop structural
            # equivalence:
            self.check_type(t.ptype, first=False, byname=True)
        elif type(t) is ast.StructureType:
            self.got_types.add(t)
            try:
                # Setup offsets of fields. Is this the right place?:
                offset = 0
                for struct_member in t.mems:
                    self.check_type(struct_member.typ, first=False)
                    struct_member.offset = offset
                    offset = offset + self.size_of(struct_member.typ)
            finally:
                # Unmark the struct once its members are done. Leaving it
                # marked would flag a second, unrelated use of the same
                # struct type (for example two fields of that type) as
                # recursion.
                self.got_types.discard(t)
        elif type(t) is ast.ArrayType:
            self.check_type(t.element_type, first=False)
        elif type(t) is ast.DefinedType:
            pass
        else:
            raise NotImplementedError('{} not implemented'.format(type(t)))