diff python/c3/typecheck.py @ 226:240111e0456f

Work on named types
author Windel Bouwman
date Fri, 12 Jul 2013 17:25:31 +0200
parents 1c7364bd74c7
children 82dfe6a32717
line wrap: on
line diff
--- a/python/c3/typecheck.py	Thu Jul 11 07:42:30 2013 +0200
+++ b/python/c3/typecheck.py	Fri Jul 12 17:25:31 2013 +0200
@@ -3,13 +3,19 @@
 from .visitor import Visitor
 
 def equalTypes(a, b):
-   """ Compare types a and b for equality. Not equal until proven otherwise. """
-   if type(a) is type(b):
-      if type(a) is BaseType:
-         return a.name == b.name
-      elif type(a) is PointerType:
-        return equalTypes(a.ptype, b.ptype)
-   return False
+    """ Compare types a and b for equality. Not equal until proven otherwise. """
+    # Recurse into named types:
+    if type(a) is DefinedType:
+        return equalTypes(a.typ, b)
+    if type(b) is DefinedType:
+        return equalTypes(a, b.typ)
+    # Compare for structural equivalence:
+    if type(a) is type(b):
+        if type(a) is BaseType:
+            return a.name == b.name
+        elif type(a) is PointerType:
+            return equalTypes(a.ptype, b.ptype)
+    return False
 
 def canCast(fromT, toT):
     if isinstance(fromT, PointerType) and isinstance(toT, PointerType):
@@ -89,19 +95,24 @@
             # pointer deref
             sym.lvalue = True
             # check if the to be dereferenced variable is a pointer type:
-            if type(sym.ptr.typ) is PointerType:
-                sym.typ = sym.ptr.typ.ptype
+            ptype = sym.ptr.typ
+            if type(ptype) is DefinedType:
+                ptype = ptype.typ
+            if type(ptype) is PointerType:
+                sym.typ = ptype.ptype
             else:
-                self.error('Cannot dereference non-pointer type {}'.format(sym.ptr.typ), sym.loc)
+                self.error('Cannot dereference non-pointer type {}'.format(ptype), sym.loc)
                 sym.typ = intType
         elif type(sym) is FieldRef:
             basetype = sym.base.typ
             sym.lvalue = True
+            if type(basetype) is DefinedType:
+                basetype = basetype.typ
             if type(basetype) is StructureType:
                 if basetype.hasField(sym.field):
                     sym.typ = basetype.fieldType(sym.field)
                 else:
-                    self.error('{} does not contain field {}'.format(basetype, symfield), sym.loc)
+                    self.error('{} does not contain field {}'.format(basetype, sym.field), sym.loc)
                     sym.typ = intType
             else:
                 self.error('Cannot select field {} of non-structure type {}'.format(sym.field, basetype), sym.loc)