view python/c3/analyse.py @ 293:6aa721e7b10b

Try to improve build sequence
author Windel Bouwman
date Thu, 28 Nov 2013 20:39:37 +0100
parents bd2593de3ff8
children
line wrap: on
line source

import logging
from .visitor import Visitor
from .astnodes import *
from .scope import *


class C3Pass:
    """ Common base of the c3 semantic passes.

    Keeps the diagnostics sink, a 'c3' logger, an ``ok`` flag that is
    cleared whenever an error is reported, and a Visitor used to walk the
    AST.
    """
    def __init__(self, diag):
        self.diag = diag
        self.ok = True
        self.logger = logging.getLogger('c3')
        self.visitor = Visitor()

    def error(self, msg, loc=None):
        """ Mark this pass as failed and forward msg (with optional loc)
        to the diagnostics object. """
        self.ok = False
        report = self.diag.error
        report(msg, loc)

    def visit(self, pkg, pre, post):
        """ Walk pkg with the shared Visitor, handing it the pre and post
        callbacks unchanged. """
        walker = self.visitor
        walker.visit(pkg, pre, post)


class AddScope(C3Pass):
    """ Scope is attached to the correct modules.

    Walks a package and, for every node, records the enclosing scope in
    ``sym.scope``. Symbols and defined types are registered in the scope
    that encloses them. Packages and functions open a new inner scope,
    stored in ``sym.innerScope``.
    """
    def addScope(self, pkg):
        """ Attach scopes to pkg and every node below it.

        Returns True when no error (e.g. a redefinition) was reported
        during this call.
        """
        # Reset the status so a reused instance does not report a stale
        # failure (Analyzer and TypeChecker reset theirs the same way).
        self.ok = True
        self.logger.info('Adding scoping to package {}'.format(pkg.name))
        # Prepare top level scope and set scope to all objects:
        self.scopeStack = [topScope]
        modScope = Scope(self.CurrentScope)
        self.scopeStack.append(modScope)
        self.visit(pkg, self.enterScope, self.quitScope)
        # Every scope pushed by enterScope must have been popped again.
        assert len(self.scopeStack) == 2
        return self.ok

    @property
    def CurrentScope(self):
        """ The innermost scope on the scope stack. """
        return self.scopeStack[-1]

    def addSymbol(self, sym):
        """ Add sym to the current scope, or report a redefinition error
        when the current scope already reports a symbol with that name. """
        if self.CurrentScope.hasSymbol(sym.name):
            self.error('Redefinition of {0}'.format(sym.name), sym.loc)
        else:
            self.CurrentScope.addSymbol(sym)

    def enterScope(self, sym):
        """ Pre-visit callback: attach the current scope to sym, register
        it if it is named, and open a new scope for packages/functions. """
        # Distribute the scope:
        sym.scope = self.CurrentScope

        # Add symbols to current scope:
        if isinstance(sym, (Symbol, DefinedType)):
            self.addSymbol(sym)

        # Create subscope:
        if type(sym) in (Package, Function):
            newScope = Scope(self.CurrentScope)
            self.scopeStack.append(newScope)
            sym.innerScope = newScope

    def quitScope(self, sym):
        """ Post-visit callback: close the scope opened by enterScope. """
        # Pop out of scope:
        if type(sym) in (Package, Function):
            self.scopeStack.pop(-1)


class Analyzer(C3Pass):
    """
        Context handling is done here.
        Resolves the imports of a package and then fixes up the names and
        references used inside it.
    """

    def analyzePackage(self, pkg, packageDict):
        """ Resolve the imports of pkg and fix up its references.

        packageDict is looked up with each entry of ``pkg.imports``; every
        package found is added to ``pkg.scope``. Returns True only when
        neither import resolution nor reference fixing reported an error.
        """
        self.ok = True

        self.logger.info('Resolving imports for package {}'.format(pkg.name))
        # Handle imports:
        for i in pkg.imports:
            if i not in packageDict:
                self.error('Cannot import {}'.format(i))
                continue
            pkg.scope.addSymbol(packageDict[i])

        fixer = FixRefs(self.diag)
        fixer.fixRefs(pkg)
        # FixRefs records failures on its own 'ok' flag; fold that in so an
        # unresolved name makes the whole analysis fail.
        if not fixer.ok:
            self.ok = False
        return self.ok


class FixRefs(C3Pass):
    """ Replace designators (symbolic names) by the objects they refer to.

    Type designators, variable uses and call targets are looked up in the
    scopes that were attached to the nodes beforehand (``sym.scope``).
    """
    def fixRefs(self, pkg):
        """ Resolve all references in pkg.

        Returns True when no resolution error was reported.
        """
        self.visitor.visit(pkg, self.findRefs)
        return self.ok

    @staticmethod
    def _designatorName(d):
        """ Human readable name of designator d, for error messages. """
        if isinstance(d, ImportDesignator):
            return '{0}.{1}'.format(d.tname, d.vname)
        return d.tname

    # Reference fixups:
    def resolveDesignator(self, d, scope):
        """ Look up designator d in scope.

        Returns the named object, or None after reporting an error when
        the name cannot be resolved.
        """
        assert isinstance(d, Designator), type(d)
        assert type(scope) is Scope
        if not scope.hasSymbol(d.tname):
            self.error('Cannot resolve name {0}'.format(d.tname), d.loc)
            return None
        s = scope.getSymbol(d.tname)
        if isinstance(d, ImportDesignator):
            # Qualified name 'tname.vname': look vname up in the inner scope
            # of whatever tname names. Report instead of silently returning
            # None (or crashing when tname has no inner scope).
            inner = getattr(s, 'innerScope', None)
            if inner is not None and inner.hasSymbol(d.vname):
                return inner.getSymbol(d.vname)
            self.error('Cannot resolve name {0}'.format(
                self._designatorName(d)), d.loc)
            return None
        if hasattr(s, 'addRef'):
            # TODO: make this nicer
            s.addRef(None)
        return s

    def resolveImportDesignator(self, d, scope):
        """ Look up the qualifier part (tname) of import designator d.

        NOTE(review): not called by findRefs in this module.
        """
        assert isinstance(d, ImportDesignator), type(d)
        assert type(scope) is Scope
        if scope.hasSymbol(d.tname):
            s = scope.getSymbol(d.tname)
            if hasattr(s, 'addRef'):
                # TODO: make this nicer
                s.addRef(None)
            return s
        else:
            self.error('Cannot resolve name {0}'.format(d.tname), d.loc)

    def resolveType(self, t, scope):
        """ Resolve all designators inside type t.

        Structure types also get their member offsets and total bytesize
        computed. Returns the resolved type, or None when a name could not
        be resolved to a type (an error has then been reported).
        """
        if type(t) is PointerType:
            t.ptype = self.resolveType(t.ptype, scope)
            return t
        elif type(t) is StructureType:
            offset = 0
            for mem in t.mems:
                mem.offset = offset
                mem.typ = self.resolveType(mem.typ, scope)
                # An unresolved member has already been reported; skip its
                # size rather than crashing on None.bytesize.
                if mem.typ is not None:
                    offset += theType(mem.typ).bytesize
            t.bytesize = offset
            return t
        elif isinstance(t, Designator):
            resolved = self.resolveDesignator(t, scope)
            if resolved is None:
                return None
            if not isinstance(resolved, Type):
                # e.g. a variable name used where a type is expected:
                # report it instead of raising a generic Exception.
                self.error('{0} is not a type'.format(
                    self._designatorName(t)), t.loc)
                return None
            return self.resolveType(resolved, scope)
        elif isinstance(t, Type):
            # Already resolved??
            return t
        else:
            raise Exception('Error resolving type {} {}'.format(t, type(t)))

    def findRefs(self, sym):
        """ Visitor callback: resolve the references held by node sym. """
        if type(sym) is Constant or isinstance(sym, Variable):
            sym.typ = self.resolveType(sym.typ, sym.scope)
        elif type(sym) is TypeCast:
            sym.to_type = self.resolveType(sym.to_type, sym.scope)
        elif type(sym) is VariableUse:
            sym.target = self.resolveDesignator(sym.target, sym.scope)
        elif type(sym) is FunctionCall:
            # sym.proc holds a variable use; replace it by the resolved callee.
            varuse = sym.proc
            sym.proc = self.resolveDesignator(varuse.target, sym.scope)
        elif type(sym) is Function:
            # Checkup function type:
            ft = sym.typ
            ft.returntype = self.resolveType(ft.returntype, sym.scope)
            ft.parametertypes = [self.resolveType(pt, sym.scope)
                                 for pt in ft.parametertypes]
            # Mark local variables:
            for d in sym.declarations:
                if isinstance(d, Variable):
                    d.isLocal = True
        elif type(sym) is DefinedType:
            sym.typ = self.resolveType(sym.typ, sym.scope)


# Type checking:

def theType(t):
    """ Look through DefinedType wrappers and return the underlying type.

    Any other value (including non-DefinedType types) is returned as-is.
    NOTE(review): a cycle of DefinedType aliases recurses without bound.
    """
    return theType(t.typ) if type(t) is DefinedType else t


def equalTypes(a, b, _assumed=None):
    """ Compare types a and b for structural equivalence.

    Named types (DefinedType) are looked through first. Two references to
    the very same type object are always equal. Otherwise:

    * base types are equal when their names are equal,
    * pointer types when their target types are equal,
    * structure types when they have the same number of members and the
      member types are pairwise equal (member names are not compared).

    _assumed is internal: the set of (id(a), id(b)) structure pairs whose
    comparison is already in progress higher up in the recursion. A pair
    seen again is assumed equal. This makes comparing self-referential
    structures (a struct holding a pointer to its own named type)
    terminate instead of recursing forever.

    Raises NotImplementedError for other kinds of types that are not the
    same object.
    """
    # Recurse into named types:
    a, b = theType(a), theType(b)

    # Identical objects are trivially equal; this also stops the unbounded
    # recursion when a recursive type is compared with itself.
    if a is b:
        return True
    if type(a) is not type(b):
        return False

    if type(a) is BaseType:
        return a.name == b.name
    elif type(a) is PointerType:
        return equalTypes(a.ptype, b.ptype, _assumed)
    elif type(a) is StructureType:
        if len(a.mems) != len(b.mems):
            return False
        if _assumed is None:
            _assumed = set()
        pair = (id(a), id(b))
        if pair in _assumed:
            # Already being compared further up: assume equal (coinduction).
            return True
        _assumed.add(pair)
        return all(equalTypes(am.typ, bm.typ, _assumed)
                   for am, bm in zip(a.mems, b.mems))
    else:
        raise NotImplementedError(
            'Type compare for {} not implemented'.format(type(a)))


def canCast(fromT, toT):
    """ Tell whether a value of type fromT may be explicitly cast to toT.

    After looking through named types, only casts *to* a pointer type are
    allowed, and only from another pointer type or from intType.
    """
    src, dst = theType(fromT), theType(toT)
    if not isinstance(dst, PointerType):
        return False
    return isinstance(src, PointerType) or src is intType


def expectRval(s):
    """ Mark expression s as used in an rvalue context.

    Sets the ``expect_rvalue`` attribute of s to True and returns None.
    TODO: solve this better
    """
    setattr(s, 'expect_rvalue', True)


class TypeChecker(C3Pass):
    """ Assigns types to expressions and checks them for consistency.

    Also sets the ``lvalue`` flag on expressions and marks operands that
    are used as values via expectRval.
    """
    def checkPackage(self, pkg):
        """ Type check pkg; return True when no type error was reported. """
        self.ok = True
        # Only a post callback: sub-expressions are typed before the
        # expression containing them, so check2 can read their .typ.
        self.visit(pkg, None, self.check2)
        return self.ok

    def check2(self, sym):
        """ Check node sym and, for expressions, compute ``sym.typ``.

        Raises NotImplementedError for node kinds not handled here, and
        Exception for unknown literal kinds or operators.
        """
        if type(sym) in [IfStatement, WhileStatement]:
            if not equalTypes(sym.condition.typ, boolType):
                msg = 'Condition must be of type {}'.format(boolType)
                self.error(msg, sym.condition.loc)
        elif type(sym) is Assignment:
            l, r = sym.lval, sym.rval
            if not equalTypes(l.typ, r.typ):
                msg = 'Cannot assign {} to {}'.format(r.typ, l.typ)
                self.error(msg, sym.loc)
            if not l.lvalue:
                self.error('No valid lvalue {}'.format(l), l.loc)
            expectRval(sym.rval)
        elif type(sym) is ReturnStatement:
            # TODO: check the returned value against the function's
            # declared return type.
            pass
        elif type(sym) is FunctionCall:
            self._checkFunctionCall(sym)
        elif type(sym) is VariableUse:
            sym.lvalue = True
            # Constants carry a declared type just like variables; they used
            # to fall through to the int default below.
            if isinstance(sym.target, (Variable, Constant)):
                sym.typ = sym.target.typ
            else:
                self.logger.warning(
                    '{} has no target, defaulting to int'.format(sym))
                sym.typ = intType
        elif type(sym) is Literal:
            sym.lvalue = False
            # Exact type() checks, so that a bool literal is not taken for
            # an int.
            if type(sym.val) is int:
                sym.typ = intType
            elif type(sym.val) is float:
                sym.typ = doubleType
            elif type(sym.val) is bool:
                sym.typ = boolType
            else:
                raise Exception('Unknown literal type {}'.format(sym.val))
        elif type(sym) is Unop:
            if sym.op == '&':
                sym.typ = PointerType(sym.a.typ)
                sym.lvalue = False
            else:
                raise Exception('Unknown unop {0}'.format(sym.op))
        elif type(sym) is Deref:
            # pointer deref
            sym.lvalue = True
            # check if the to be dereferenced variable is a pointer type:
            ptype = theType(sym.ptr.typ)
            if type(ptype) is PointerType:
                sym.typ = ptype.ptype
            else:
                self.error(
                    'Cannot dereference non-pointer type {}'.format(ptype),
                    sym.loc)
                sym.typ = intType
        elif type(sym) is FieldRef:
            sym.lvalue = sym.base.lvalue
            basetype = theType(sym.base.typ)
            if type(basetype) is StructureType:
                if basetype.hasField(sym.field):
                    sym.typ = basetype.fieldType(sym.field)
                else:
                    self.error(
                        '{} does not contain field {}'.format(
                            basetype, sym.field),
                        sym.loc)
                    sym.typ = intType
            else:
                self.error(
                    'Cannot select field {} of non-structure type {}'.format(
                        sym.field, basetype),
                    sym.loc)
                sym.typ = intType
        elif type(sym) is Binop:
            self._checkBinop(sym)
        elif isinstance(sym, Variable):
            # TODO: check the type of the initial value.
            pass
        elif type(sym) is TypeCast:
            if canCast(sym.a.typ, sym.to_type):
                sym.typ = sym.to_type
            else:
                self.error(
                    'Cannot cast {} to {}'.format(sym.a.typ, sym.to_type),
                    sym.loc)
                sym.typ = intType
        elif type(sym) is Constant:
            if not equalTypes(sym.typ, sym.value.typ):
                self.error(
                    'Cannot assign {0} to {1}'.format(sym.value.typ, sym.typ),
                    sym.loc)
        elif type(sym) in [CompoundStatement, Package, Function, FunctionType,
                           ExpressionStatement, DefinedType]:
            pass
        else:
            raise NotImplementedError('Unknown type check {0}'.format(sym))

    def _checkFunctionCall(self, sym):
        """ Check the arguments of call sym against the callee's parameter
        types and set sym.typ to the callee's return type. """
        proc = sym.proc
        ftyp = getattr(proc, 'typ', None)
        if type(ftyp) is not FunctionType:
            # proc is None when the callee name did not resolve; that is
            # typically reported while resolving references, so only
            # complain about callees that resolved to a non-function.
            if proc is not None:
                self.error('Cannot call {}: not a function'.format(proc),
                           sym.loc)
            sym.typ = intType
            return
        # Check arguments:
        ngiv = len(sym.args)
        ptypes = ftyp.parametertypes
        nreq = len(ptypes)
        if ngiv != nreq:
            self.error(
                'Function {2}: {0} arguments required, {1} given'.format(
                    nreq, ngiv, proc.name),
                sym.loc)
        else:
            for a, at in zip(sym.args, ptypes):
                expectRval(a)
                if not equalTypes(a.typ, at):
                    self.error('Got {0}, expected {1}'.format(a.typ, at),
                               a.loc)
        # determine return type:
        sym.typ = ftyp.returntype

    def _checkBinop(self, sym):
        """ Type check binary operation sym and set sym.typ. """
        sym.lvalue = False
        if sym.op in ['+', '-', '*', '/', '<<', '>>', '|', '&']:
            expectRval(sym.a)
            expectRval(sym.b)
            if equalTypes(sym.a.typ, sym.b.typ):
                if equalTypes(sym.a.typ, intType):
                    sym.typ = sym.a.typ
                else:
                    self.error('Can only add integers', sym.loc)
                    sym.typ = intType
            else:
                # assume void here? TODO: throw exception!
                sym.typ = intType
                self.error(
                    'Types unequal {} != {}'.format(sym.a.typ, sym.b.typ),
                    sym.loc)
        elif sym.op in ['>', '<', '==', '!=', '<=', '>=']:
            expectRval(sym.a)
            expectRval(sym.b)
            sym.typ = boolType
            if not equalTypes(sym.a.typ, sym.b.typ):
                self.error(
                    'Types unequal {} != {}'.format(sym.a.typ, sym.b.typ),
                    sym.loc)
        elif sym.op in ['or', 'and']:
            sym.typ = boolType
            if not equalTypes(sym.a.typ, boolType):
                self.error('Must be {0}'.format(boolType), sym.a.loc)
            if not equalTypes(sym.b.typ, boolType):
                self.error('Must be {0}'.format(boolType), sym.b.loc)
        else:
            raise Exception('Unknown binop {0}'.format(sym.op))