diff python/c3/analyse.py @ 288:a747a45dcd78

Various styling work
author Windel Bouwman
date Thu, 21 Nov 2013 14:26:13 +0100
parents 1c7c1e619be8
children bd2593de3ff8
line wrap: on
line diff
--- a/python/c3/analyse.py	Thu Nov 21 11:57:27 2013 +0100
+++ b/python/c3/analyse.py	Thu Nov 21 14:26:13 2013 +0100
@@ -4,47 +4,37 @@
 from .scope import *
 
 
-class Analyzer:
-    """ 
-        Context handling is done here.
-        Scope is attached to the correct modules.
-        This class checks names and references.
-    """
+class C3Pass:
     def __init__(self, diag):
         self.diag = diag
         self.logger = logging.getLogger('c3')
-
-    def analyzePackage(self, pkg, packageProvider):
-        self.logger.info('Checking package {}'.format(pkg.name))
         self.ok = True
-        visitor = Visitor()
-        # Prepare top level scope and set scope to all objects:
-        self.scopeStack = [topScope]
-        modScope = Scope(self.CurrentScope)
-        self.scopeStack.append(modScope)
-        visitor.visit(pkg, self.enterScope, self.quitScope)
-        del self.scopeStack
-
-        # Handle imports:
-        for i in pkg.imports:
-            ip = packageProvider.getPackage(i)
-            if not ip:
-                self.error('Cannot import {}'.format(i))
-                continue
-            for x in ip.declarations:
-                modScope.addSymbol(x)
-        visitor.visit(pkg, self.findRefs)
-        return self.ok
+        self.visitor = Visitor()
 
     def error(self, msg, loc=None):
         self.ok = False
         self.diag.error(msg, loc)
 
+    def visit(self, pkg, pre, post):
+        self.visitor.visit(pkg, pre, post)
+
+
+class AddScope(C3Pass):
+    """ Scope is attached to the correct modules. """
+    def addScope(self, pkg):
+        self.logger.info('Adding scoping to package {}'.format(pkg.name))
+        # Prepare top level scope and set scope to all objects:
+        self.scopeStack = [topScope]
+        modScope = Scope(self.CurrentScope)
+        self.scopeStack.append(modScope)
+        self.visit(pkg, self.enterScope, self.quitScope)
+        assert len(self.scopeStack) == 2
+        return self.ok
+
     @property
     def CurrentScope(self):
         return self.scopeStack[-1]
 
-    # Scope creation:
     def addSymbol(self, sym):
         if self.CurrentScope.hasSymbol(sym.name):
             self.error('Redefinition of {0}'.format(sym.name), sym.loc)
@@ -70,6 +60,35 @@
         if type(sym) in [Package, Function]:
             self.scopeStack.pop(-1)
 
+
+class Analyzer(C3Pass):
+    """
+        Context handling is done here.
+        Scope is attached to the correct modules.
+        This class checks names and references.
+    """
+
+    def analyzePackage(self, pkg, packageDict):
+        self.ok = True
+        # Prepare top level scope and set scope to all objects:
+        AddScope(self.diag).addScope(pkg)
+
+        self.logger.info('Resolving imports for package {}'.format(pkg.name))
+        # Handle imports:
+        for i in pkg.imports:
+            ip = packageDict[i]
+            if not ip:
+                self.error('Cannot import {}'.format(i))
+                continue
+            pkg.scope.addSymbol(ip)
+        FixRefs(self.diag).fixRefs(pkg)
+        return self.ok
+
+
+class FixRefs(C3Pass):
+    def fixRefs(self, pkg):
+        self.visitor.visit(pkg, self.findRefs)
+
     # Reference fixups:
     def resolveDesignator(self, d, scope):
         assert type(d) is Designator, type(d)
@@ -119,7 +138,8 @@
             # Checkup function type:
             ft = sym.typ
             ft.returntype = self.resolveType(ft.returntype, sym.scope)
-            ft.parametertypes = [self.resolveType(pt, sym.scope) for pt in ft.parametertypes]
+            ft.parametertypes = [self.resolveType(pt, sym.scope) for pt in
+                                 ft.parametertypes]
             # Mark local variables:
             for d in sym.declarations:
                 if isinstance(d, Variable):
@@ -127,26 +147,21 @@
         elif type(sym) is DefinedType:
             sym.typ = self.resolveType(sym.typ, sym.scope)
 
+
 # Type checking:
 
 def theType(t):
-    """
-        Recurse until a 'real' type is found
-    """
+    """ Recurse until a 'real' type is found """
     if type(t) is DefinedType:
         return theType(t.typ)
     return t
 
+
 def equalTypes(a, b):
-    """
-        Compare types a and b for equality.
-        Not equal until proven otherwise.
-    """
+    """ Compare types a and b for structural equavalence. """
     # Recurse into named types:
-    a = theType(a)
-    b = theType(b)
+    a, b = theType(a), theType(b)
 
-    # Compare for structural equivalence:
     if type(a) is type(b):
         if type(a) is BaseType:
             return a.name == b.name
@@ -155,14 +170,14 @@
         elif type(a) is StructureType:
             if len(a.mems) != len(b.mems):
                 return False
-            for amem, bmem in zip(a.mems, b.mems):
-                if not equalTypes(amem.typ, bmem.typ):
-                    return False
-            return True
+            return all(equalTypes(am.typ, bm.typ) for am, bm in
+                       zip(a.mems, b.mems))
         else:
-            raise Exception('Type compare for {} not implemented'.format(type(a)))
+            raise NotImplementedError(
+                    'Type compare for {} not implemented'.format(type(a)))
     return False
 
+
 def canCast(fromT, toT):
     fromT = theType(fromT)
     toT = theType(toT)
@@ -172,25 +187,16 @@
         return True
     return False
 
+
 def expectRval(s):
     # TODO: solve this better
     s.expect_rvalue = True
 
-class TypeChecker:
-    def __init__(self, diag):
-        self.diag = diag
 
-    def error(self, msg, loc):
-        """
-            Wrapper that registers the message and marks the result invalid
-        """
-        self.diag.error(msg, loc)
-        self.ok = False
-
+class TypeChecker(C3Pass):
     def checkPackage(self, pkg):
         self.ok = True
-        visitor = Visitor()
-        visitor.visit(pkg, f_post=self.check2)
+        self.visit(pkg, None, self.check2)
         return self.ok
 
     def check2(self, sym):