diff python/c3/typecheck.py @ 230:88a1e0baef65

Added some tests for IR-code
author Windel Bouwman
date Sat, 13 Jul 2013 19:53:44 +0200
parents 7f18ed9b6b7e
children 521567d17388
line wrap: on
line diff
--- a/python/c3/typecheck.py	Sat Jul 13 11:13:01 2013 +0200
+++ b/python/c3/typecheck.py	Sat Jul 13 19:53:44 2013 +0200
@@ -2,16 +2,20 @@
 from .scope import *
 from .visitor import Visitor
 
+def resolveType(t):
+    if type(t) is DefinedType:
+        return resolveType(t.typ)
+    return t
+
 def equalTypes(a, b):
-    """ 
+    """
         Compare types a and b for equality.
         Not equal until proven otherwise.
     """
     # Recurse into named types:
-    if type(a) is DefinedType:
-        return equalTypes(a.typ, b)
-    if type(b) is DefinedType:
-        return equalTypes(a, b.typ)
+    a = resolveType(a)
+    b = resolveType(b)
+
     # Compare for structural equivalence:
     if type(a) is type(b):
         if type(a) is BaseType:
@@ -30,19 +34,25 @@
     return False
 
 def canCast(fromT, toT):
+    fromT = resolveType(fromT)
+    toT = resolveType(toT)
     if isinstance(fromT, PointerType) and isinstance(toT, PointerType):
         return True
     elif fromT is intType and isinstance(toT, PointerType):
         return True
     return False
 
+def expectRval(s):
+    # TODO: solve this better
+    s.expect_rvalue = True
+
 class TypeChecker:
     def __init__(self, diag):
         self.diag = diag
 
     def error(self, msg, loc):
-        """ 
-            Wrapper that registers the message and marks the result invalid 
+        """
+            Wrapper that registers the message and marks the result invalid
         """
         self.diag.error(msg, loc)
         self.ok = False
@@ -56,14 +66,18 @@
     def check2(self, sym):
         if type(sym) in [IfStatement, WhileStatement]:
             if not equalTypes(sym.condition.typ, boolType):
-                self.error('Condition must be of type {0}'.format(boolType), sym.condition.loc)
+                msg = 'Condition must be of type {}'.format(boolType)
+                self.error(msg, sym.condition.loc)
         elif type(sym) is Assignment:
-            if not equalTypes(sym.lval.typ, sym.rval.typ):
-                self.error('Cannot assign {0} to {1}'.format(sym.rval.typ, sym.lval.typ), sym.loc)
-            if not sym.lval.lvalue:
-                self.error('No valid lvalue {}'.format(sym.lval), sym.lval.loc)
+            l, r = sym.lval, sym.rval
+            if not equalTypes(l.typ, r.typ):
+                msg = 'Cannot assign {} to {}'.format(r.typ, l.typ)
+                self.error(msg, sym.loc)
+            if not l.lvalue:
+                self.error('No valid lvalue {}'.format(l), l.loc)
             #if sym.rval.lvalue:
             #    self.error('Right hand side must be an rvalue', sym.rval.loc)
+            expectRval(sym.rval)
         elif type(sym) is ReturnStatement:
             pass
         elif type(sym) is FunctionCall:
@@ -75,6 +89,7 @@
                self.error('Function {2}: {0} arguments required, {1} given'.format(nreq, ngiv, sym.proc.name), sym.loc)
             else:
                for a, at in zip(sym.args, ptypes):
+                  expectRval(a)
                   if not equalTypes(a.typ, at):
                      self.error('Got {0}, expected {1}'.format(a.typ, at), a.loc)
             # determine return type:
@@ -106,9 +121,7 @@
             # pointer deref
             sym.lvalue = True
             # check if the to be dereferenced variable is a pointer type:
-            ptype = sym.ptr.typ
-            if type(ptype) is DefinedType:
-                ptype = ptype.typ
+            ptype = resolveType(sym.ptr.typ)
             if type(ptype) is PointerType:
                 sym.typ = ptype.ptype
             else:
@@ -116,9 +129,8 @@
                 sym.typ = intType
         elif type(sym) is FieldRef:
             basetype = sym.base.typ
-            sym.lvalue = True
-            if type(basetype) is DefinedType:
-                basetype = basetype.typ
+            sym.lvalue = sym.base.lvalue
+            basetype = resolveType(basetype)
             if type(basetype) is StructureType:
                 if basetype.hasField(sym.field):
                     sym.typ = basetype.fieldType(sym.field)
@@ -131,6 +143,8 @@
         elif type(sym) is Binop:
             sym.lvalue = False
             if sym.op in ['+', '-', '*', '/']:
+                expectRval(sym.a)
+                expectRval(sym.b)
                 if equalTypes(sym.a.typ, sym.b.typ):
                    if equalTypes(sym.a.typ, intType):
                       sym.typ = sym.a.typ
@@ -169,12 +183,13 @@
             if canCast(sym.a.typ, sym.to_type):
                 sym.typ = sym.to_type
             else:
-                self.error('Cannot cast {} to {}'.format(sym.a.typ, sym.to_type))
+                self.error('Cannot cast {} to {}'.format(sym.a.typ, sym.to_type), sym.loc)
+                sym.typ = intType
         elif type(sym) is Constant:
             if not equalTypes(sym.typ, sym.value.typ):
                 self.error('Cannot assign {0} to {1}'.format(sym.value.typ, sym.typ), sym.loc)
         elif type(sym) in [CompoundStatement, Package, Function, FunctionType, ExpressionStatement, DefinedType]:
-         pass
+            pass
         else:
             raise Exception('Unknown type check {0}'.format(sym))