diff python/c3/analyse.py @ 287:1c7c1e619be8

File movage
author Windel Bouwman
date Thu, 21 Nov 2013 11:57:27 +0100
parents 6f2423df0675
children a747a45dcd78
line wrap: on
line diff
--- a/python/c3/analyse.py	Fri Nov 15 13:52:32 2013 +0100
+++ b/python/c3/analyse.py	Thu Nov 21 11:57:27 2013 +0100
@@ -1,14 +1,14 @@
 import logging
 from .visitor import Visitor
 from .astnodes import *
-from .scope import Scope, topScope
-from .typecheck import theType
+from .scope import *
+
 
 class Analyzer:
     """ 
         Context handling is done here.
         Scope is attached to the correct modules.
-        This class checks names and references 
+        This class checks names and references.
     """
     def __init__(self, diag):
         self.diag = diag
@@ -18,7 +18,7 @@
         self.logger.info('Checking package {}'.format(pkg.name))
         self.ok = True
         visitor = Visitor()
-        # Prepare top level scope:
+        # Prepare top level scope and set scope to all objects:
         self.scopeStack = [topScope]
         modScope = Scope(self.CurrentScope)
         self.scopeStack.append(modScope)
@@ -34,7 +34,6 @@
             for x in ip.declarations:
                 modScope.addSymbol(x)
         visitor.visit(pkg, self.findRefs)
-        visitor.visit(pkg, self.sanity)
         return self.ok
 
     def error(self, msg, loc=None):
@@ -82,12 +81,9 @@
                 s.addRef(None)
             return s
         else:
-            self.ok = False
-            msg = 'Cannot resolve name {0}'.format(d.tname)
-            self.diag.error(msg, d.loc)
+            self.error('Cannot resolve name {0}'.format(d.tname), d.loc)
 
     def resolveType(self, t, scope):
-        # TODO: what about structs?
         if type(t) is PointerType:
             t.ptype = self.resolveType(t.ptype, scope)
             return t
@@ -131,9 +127,192 @@
         elif type(sym) is DefinedType:
             sym.typ = self.resolveType(sym.typ, sym.scope)
 
-    def sanity(self, sym):
-        if type(sym) is FunctionType:
+# Type checking:
+
+def theType(t):
+    """
+        Recurse until a 'real' type is found
+    """
+    if type(t) is DefinedType:
+        return theType(t.typ)
+    return t
+
+def equalTypes(a, b):
+    """
+        Compare types a and b for equality.
+        Not equal until proven otherwise.
+    """
+    # Recurse into named types:
+    a = theType(a)
+    b = theType(b)
+
+    # Compare for structural equivalence:
+    if type(a) is type(b):
+        if type(a) is BaseType:
+            return a.name == b.name
+        elif type(a) is PointerType:
+            return equalTypes(a.ptype, b.ptype)
+        elif type(a) is StructureType:
+            if len(a.mems) != len(b.mems):
+                return False
+            for amem, bmem in zip(a.mems, b.mems):
+                if not equalTypes(amem.typ, bmem.typ):
+                    return False
+            return True
+        else:
+            raise Exception('Type compare for {} not implemented'.format(type(a)))
+    return False
+
+def canCast(fromT, toT):
+    fromT = theType(fromT)
+    toT = theType(toT)
+    if isinstance(fromT, PointerType) and isinstance(toT, PointerType):
+        return True
+    elif fromT is intType and isinstance(toT, PointerType):
+        return True
+    return False
+
+def expectRval(s):
+    # TODO: solve this better
+    s.expect_rvalue = True
+
+class TypeChecker:
+    def __init__(self, diag):
+        self.diag = diag
+
+    def error(self, msg, loc):
+        """
+            Wrapper that registers the message and marks the result invalid
+        """
+        self.diag.error(msg, loc)
+        self.ok = False
+
+    def checkPackage(self, pkg):
+        self.ok = True
+        visitor = Visitor()
+        visitor.visit(pkg, f_post=self.check2)
+        return self.ok
+
+    def check2(self, sym):
+        if type(sym) in [IfStatement, WhileStatement]:
+            if not equalTypes(sym.condition.typ, boolType):
+                msg = 'Condition must be of type {}'.format(boolType)
+                self.error(msg, sym.condition.loc)
+        elif type(sym) is Assignment:
+            l, r = sym.lval, sym.rval
+            if not equalTypes(l.typ, r.typ):
+                msg = 'Cannot assign {} to {}'.format(r.typ, l.typ)
+                self.error(msg, sym.loc)
+            if not l.lvalue:
+                self.error('No valid lvalue {}'.format(l), l.loc)
+            #if sym.rval.lvalue:
+            #    self.error('Right hand side must be an rvalue', sym.rval.loc)
+            expectRval(sym.rval)
+        elif type(sym) is ReturnStatement:
             pass
-        elif type(sym) is Function:
+        elif type(sym) is FunctionCall:
+            # Check arguments:
+            ngiv = len(sym.args)
+            ptypes = sym.proc.typ.parametertypes
+            nreq = len(ptypes)
+            if ngiv != nreq:
+               self.error('Function {2}: {0} arguments required, {1} given'.format(nreq, ngiv, sym.proc.name), sym.loc)
+            else:
+               for a, at in zip(sym.args, ptypes):
+                  expectRval(a)
+                  if not equalTypes(a.typ, at):
+                     self.error('Got {0}, expected {1}'.format(a.typ, at), a.loc)
+            # determine return type:
+            sym.typ = sym.proc.typ.returntype
+        elif type(sym) is VariableUse:
+            sym.lvalue = True
+            if isinstance(sym.target, Variable):
+                sym.typ = sym.target.typ
+            else:
+                print('warning {} has no target, defaulting to int'.format(sym))
+                sym.typ = intType
+        elif type(sym) is Literal:
+            sym.lvalue = False
+            if type(sym.val) is int:
+                sym.typ = intType
+            elif type(sym.val) is float:
+                sym.typ = doubleType
+            elif type(sym.val) is bool:
+                sym.typ = boolType
+            else:
+                raise Exception('Unknown literal type'.format(sym.val))
+        elif type(sym) is Unop:
+            if sym.op == '&':
+                sym.typ = PointerType(sym.a.typ)
+                sym.lvalue = False
+            else:
+                raise Exception('Unknown unop {0}'.format(sym.op))
+        elif type(sym) is Deref:
+            # pointer deref
+            sym.lvalue = True
+            # check if the to be dereferenced variable is a pointer type:
+            ptype = theType(sym.ptr.typ)
+            if type(ptype) is PointerType:
+                sym.typ = ptype.ptype
+            else:
+                self.error('Cannot dereference non-pointer type {}'.format(ptype), sym.loc)
+                sym.typ = intType
+        elif type(sym) is FieldRef:
+            basetype = sym.base.typ
+            sym.lvalue = sym.base.lvalue
+            basetype = theType(basetype)
+            if type(basetype) is StructureType:
+                if basetype.hasField(sym.field):
+                    sym.typ = basetype.fieldType(sym.field)
+                else:
+                    self.error('{} does not contain field {}'.format(basetype, sym.field), sym.loc)
+                    sym.typ = intType
+            else:
+                self.error('Cannot select field {} of non-structure type {}'.format(sym.field, basetype), sym.loc)
+                sym.typ = intType
+        elif type(sym) is Binop:
+            sym.lvalue = False
+            if sym.op in ['+', '-', '*', '/', '<<', '>>', '|', '&']:
+                expectRval(sym.a)
+                expectRval(sym.b)
+                if equalTypes(sym.a.typ, sym.b.typ):
+                   if equalTypes(sym.a.typ, intType):
+                      sym.typ = sym.a.typ
+                   else:
+                      self.error('Can only add integers', sym.loc)
+                      sym.typ = intType
+                else:
+                   # assume void here? TODO: throw exception!
+                   sym.typ = intType
+                   self.error('Types unequal {} != {}'.format(sym.a.typ, sym.b.typ), sym.loc)
+            elif sym.op in ['>', '<', '==', '<=', '>=']:
+                expectRval(sym.a)
+                expectRval(sym.b)
+                sym.typ = boolType
+                if not equalTypes(sym.a.typ, sym.b.typ):
+                   self.error('Types unequal {} != {}'.format(sym.a.typ, sym.b.typ), sym.loc)
+            elif sym.op in ['or', 'and']:
+                sym.typ = boolType
+                if not equalTypes(sym.a.typ, boolType):
+                   self.error('Must be {0}'.format(boolType), sym.a.loc)
+                if not equalTypes(sym.b.typ, boolType):
+                   self.error('Must be {0}'.format(boolType), sym.b.loc)
+            else:
+                raise Exception('Unknown binop {0}'.format(sym.op))
+        elif isinstance(sym, Variable):
+            # check initial value type:
+            # TODO
             pass
-
+        elif type(sym) is TypeCast:
+            if canCast(sym.a.typ, sym.to_type):
+                sym.typ = sym.to_type
+            else:
+                self.error('Cannot cast {} to {}'.format(sym.a.typ, sym.to_type), sym.loc)
+                sym.typ = intType
+        elif type(sym) is Constant:
+            if not equalTypes(sym.typ, sym.value.typ):
+                self.error('Cannot assign {0} to {1}'.format(sym.value.typ, sym.typ), sym.loc)
+        elif type(sym) in [CompoundStatement, Package, Function, FunctionType, ExpressionStatement, DefinedType]:
+            pass
+        else:
+            raise NotImplementedError('Unknown type check {0}'.format(sym))