changeset 215:c1ccb1cb4cef

Major changes in c3 frontend
author Windel Bouwman
date Fri, 05 Jul 2013 13:00:03 +0200
parents 6875360e8390
children 57c032c5e753
files python/c3/__init__.py python/c3/analyse.py python/c3/astnodes.py python/c3/astprinter.py python/c3/builder.py python/c3/parser.py python/c3/typecheck.py python/c3/visitor.py python/ppci/errors.py python/testc3.py
diffstat 10 files changed, 206 insertions(+), 145 deletions(-) [+]
line wrap: on
line diff
--- a/python/c3/__init__.py	Fri Jul 05 11:18:58 2013 +0200
+++ b/python/c3/__init__.py	Fri Jul 05 13:00:03 2013 +0200
@@ -6,7 +6,6 @@
 # Convenience imports:
 
 from .parser import Parser
-from .semantics import Semantics
 from .typecheck import TypeChecker
 from .analyse import Analyzer
 from .codegenerator import CodeGenerator
--- a/python/c3/analyse.py	Fri Jul 05 11:18:58 2013 +0200
+++ b/python/c3/analyse.py	Fri Jul 05 13:00:03 2013 +0200
@@ -1,40 +1,86 @@
 from .visitor import Visitor
 from .astnodes import *
+from .scope import Scope, topScope
 
 class Analyzer:
-   """ This class checks names and references """
-   def __init__(self, diag):
-      self.diag = diag
-      self.visitor = Visitor(self.a1, self.analyze)
+    """ 
+        Context handling is done here.
+        Scope is attached to the correct modules.
+        This class checks names and references 
+    """
+    def __init__(self, diag):
+        self.diag = diag
 
-   def analyzePackage(self, pkg):
+    def analyzePackage(self, pkg):
         self.ok = True
-        self.visitor.visit(pkg)
+        visitor = Visitor()
+        # Prepare top level scope:
+        self.curScope = topScope
+        visitor.visit(pkg, self.enterScope, self.quitScope)
+        del self.curScope
+        visitor.visit(pkg, self.findRefs)
+        visitor.visit(pkg, self.sanity)
         return self.ok
-   def resolveDesignator(self, d, referee=None):
-      assert type(d) is Designator
-      if d.scope.hasSymbol(d.tname):
-         s = d.scope.getSymbol(d.tname)
-         if hasattr(s, 'addRef'):
-            # TODO: make this nicer
-            s.addRef(referee)
-         return s
-      else:
-         self.ok = False
-         msg = 'Cannot resolve name {0}'.format(d.tname)
-         self.diag.error(msg, d.loc)
-   def a1(self, sym):
-      pass
-   def analyze(self, sym):
-      if type(sym) in [Variable, Constant]:
-         sym.typ = self.resolveDesignator(sym.typ, sym)
-      elif type(sym) is Function:
-         pass
-      elif type(sym) is VariableUse:
-         sym.target = self.resolveDesignator(sym.target, sym)
-      elif type(sym) is FunctionCall:
-         sym.proc = self.resolveDesignator(sym.proc, sym)
-      elif type(sym) is FunctionType:
-         sym.returntype = self.resolveDesignator(sym.returntype)
-         sym.parametertypes = [self.resolveDesignator(pt) for pt in sym.parametertypes]
+
+    def error(self, msg, loc=None):
+        self.ok = False
+        self.diag.error(msg, loc)
+
+    # Scope creation:
+    def addSymbol(self, sym):
+        if self.curScope.hasSymbol(sym.name):
+            self.error('Redefinition of {0}'.format(sym.name), sym.loc)
+        else:
+            self.curScope.addSymbol(sym)
+
+    def enterScope(self, sym):
+        # Distribute the scope:
+        sym.scope = self.curScope
+
+        # Add symbols to current scope:
+        if isinstance(sym, Symbol):
+            self.addSymbol(sym)
+
+        # Create subscope:
+        if type(sym) in [Package, Function]:
+            self.curScope = Scope(self.curScope)
+
+    def quitScope(self, sym):
+        # Pop out of scope:
+        if type(sym) in [Package, Function]:
+            self.curScope = self.curScope.parent
 
+    # Reference fixups:
+    def resolveDesignator(self, d, scope):
+        assert type(d) is Designator
+        assert type(scope) is Scope
+        if scope.hasSymbol(d.tname):
+            s = scope.getSymbol(d.tname)
+            if hasattr(s, 'addRef'):
+                # TODO: make this nicer
+                s.addRef(None)
+            return s
+        else:
+            self.ok = False
+            msg = 'Cannot resolve name {0}'.format(d.tname)
+            self.diag.error(msg, d.loc)
+
+    def findRefs(self, sym):
+        if type(sym) in [Variable, Constant]:
+             sym.typ = self.resolveDesignator(sym.typ, sym.scope)
+        elif type(sym) is VariableUse:
+             sym.target = self.resolveDesignator(sym.target, sym.scope)
+        elif type(sym) is FunctionCall:
+             sym.proc = self.resolveDesignator(sym.proc, sym.scope)
+        elif type(sym) is Function:
+            # Checkup function type:
+            ft = sym.typ
+            ft.returntype = self.resolveDesignator(ft.returntype, sym.scope)
+            ft.parametertypes = [self.resolveDesignator(pt, sym.scope) for pt in ft.parametertypes]
+
+    def sanity(self, sym):
+        if type(sym) is FunctionType:
+            pass
+        elif type(sym) is Function:
+            pass
+
--- a/python/c3/astnodes.py	Fri Jul 05 11:18:58 2013 +0200
+++ b/python/c3/astnodes.py	Fri Jul 05 13:00:03 2013 +0200
@@ -106,6 +106,7 @@
     def __init__(self, name, loc):
         super().__init__(name)
         self.loc = loc
+        self.declarations = []
 
     def __repr__(self):
         return '{}'.format(self.name)
--- a/python/c3/astprinter.py	Fri Jul 05 11:18:58 2013 +0200
+++ b/python/c3/astprinter.py	Fri Jul 05 13:00:03 2013 +0200
@@ -1,13 +1,10 @@
-from .astnodes import *
-from .scope import *
 from .visitor import Visitor
 
 class AstPrinter:
-   def __init__(self):
-      self.visitor = Visitor(self.print1, self.print2)
    def printAst(self, pkg):
       self.indent = 0
-      self.visitor.visit(pkg)
+      visitor = Visitor()
+      visitor.visit(pkg, self.print1, self.print2)
    def print1(self, node):
       print(' ' * self.indent + str(node))
       self.indent += 2
--- a/python/c3/builder.py	Fri Jul 05 11:18:58 2013 +0200
+++ b/python/c3/builder.py	Fri Jul 05 13:00:03 2013 +0200
@@ -1,5 +1,5 @@
 import ppci
-from . import Parser, Semantics, TypeChecker, Analyzer, CodeGenerator, AstPrinter
+from . import Parser, TypeChecker, Analyzer, CodeGenerator
 
 class Builder:
     """ 
@@ -12,13 +12,13 @@
       self.tc = TypeChecker(diag)
       self.al = Analyzer(diag)
       self.cg = CodeGenerator()
-      self.ap = AstPrinter()
     def build(self, src):
       """ Create IR-code from sources """
       pkg = self.parser.parseSource(src)
       if not pkg:
             return
       self.pkg = pkg
+      # TODO: merge the two below?
       if not self.al.analyzePackage(pkg):
             return
       if not self.tc.checkPackage(pkg):
--- a/python/c3/parser.py	Fri Jul 05 11:18:58 2013 +0200
+++ b/python/c3/parser.py	Fri Jul 05 13:00:03 2013 +0200
@@ -49,6 +49,8 @@
     def initLex(self, source):
       self.tokens = lexer.tokenize(source) # Lexical stage
       self.token = self.tokens.__next__()
+    def addDeclaration(self, decl):
+        self.currentPart.declarations.append(decl)
     
     def parseUses(self):
         pass
@@ -58,6 +60,7 @@
       name = self.Consume('ID')
       self.Consume(';')
       self.mod = astnodes.Package(name.val, name.loc)
+      self.currentPart = self.mod
       self.parseUses()
       # TODO: parse uses
       while self.Peak != 'END':
@@ -125,6 +128,7 @@
          v.loc = name.loc
          if self.hasConsumed('='):
             v.ival = self.parseExpression()
+         self.addDeclaration(v)
       parseVar()
       while self.hasConsumed(','):
          parseVar()
@@ -150,6 +154,9 @@
       returntype = self.parseTypeSpec()
       fname = self.Consume('ID').val
       f = astnodes.Function(fname, loc)
+      self.addDeclaration(f)
+      savePart = self.currentPart
+      self.currentPart = f
       self.Consume('(')
       parameters = []
       if not self.hasConsumed(')'):
@@ -158,12 +165,16 @@
             name = self.Consume('ID')
             param = astnodes.Variable(name.val, typ)
             param.loc = name.loc
+            self.addDeclaration(param)
             parameters.append(param)
          parseParameter()
          while self.hasConsumed(','):
             parseParameter()
          self.Consume(')')
-      body = self.parseCompoundStatement()
+      paramtypes = [p.typ for p in parameters]
+      f.typ = astnodes.FunctionType(paramtypes, returntype)
+      f.body = self.parseCompoundStatement()
+      self.currentPart = savePart
 
     # Statements:
     def parseAssignment(self, lval):
--- a/python/c3/typecheck.py	Fri Jul 05 11:18:58 2013 +0200
+++ b/python/c3/typecheck.py	Fri Jul 05 13:00:03 2013 +0200
@@ -12,17 +12,15 @@
 class TypeChecker:
    def __init__(self, diag):
       self.diag = diag
-      self.visitor = Visitor(self.precheck, self.check2)
    def error(self, msg, loc):
         """ Wrapper that registers the message and marks the result invalid """
         self.diag.error(msg, loc)
         self.ok = False
    def checkPackage(self, pkg):
         self.ok = True
-        self.visitor.visit(pkg)
+        visitor = Visitor()
+        visitor.visit(pkg, f_post=self.check2)
         return self.ok
-   def precheck(self, sym):
-      pass
    def check2(self, sym):
       if type(sym) is Function:
          pass
--- a/python/c3/visitor.py	Fri Jul 05 11:18:58 2013 +0200
+++ b/python/c3/visitor.py	Fri Jul 05 13:00:03 2013 +0200
@@ -1,49 +1,58 @@
 from .astnodes import *
 
 class Visitor:
-   """ Visitor that visits all nodes in the ast and runs the function 'f' """
-   def __init__(self, f1, f2):
-      self.f1 = f1
-      self.f2 = f2
-   def visit(self, node):
+    """ 
+        Visitor that can visit all nodes in the AST
+        and run pre and post functions.
+    """
+    def visit(self, node, f_pre=None, f_post=None):
+        self.f_pre = f_pre
+        self.f_post = f_post
+        self.do(node)
+
+    def do(self, node):
       # Run visitor:
-      self.f1(node)
+      if self.f_pre:
+            self.f_pre(node)
+
       # Descent into subnodes:
       if type(node) is Package:
-         for s in node.scope:
-            self.visit(s)
+            for decl in node.declarations:
+                self.do(decl)
       elif type(node) is Function:
-         for s in node.scope:
-            self.visit(s)
-         self.visit(node.typ)
-         self.visit(node.body)
+            for s in node.declarations:
+                self.do(s)
+            self.do(node.body)
       elif type(node) is CompoundStatement:
          for s in node.statements:
-            self.visit(s)
+            self.do(s)
       elif type(node) is IfStatement:
-         self.visit(node.condition)
-         self.visit(node.truestatement)
-         self.visit(node.falsestatement)
+         self.do(node.condition)
+         self.do(node.truestatement)
+         self.do(node.falsestatement)
       elif type(node) is FunctionCall:
          for arg in node.args:
-            self.visit(arg)
+            self.do(arg)
       elif type(node) is Assignment:
-         self.visit(node.lval)
-         self.visit(node.rval)
+         self.do(node.lval)
+         self.do(node.rval)
       elif type(node) is ReturnStatement:
-         self.visit(node.expr)
+         self.do(node.expr)
       elif type(node) is Binop:
-         self.visit(node.a)
-         self.visit(node.b)
+         self.do(node.a)
+         self.do(node.b)
       elif type(node) is Constant:
-         self.visit(node.value)
+         self.do(node.value)
       elif type(node) in [EmptyStatement, VariableUse, Variable, Literal, FunctionType]:
          # Those nodes do not have child nodes.
          pass
       elif type(node) is WhileStatement:
-         self.visit(node.condition)
-         self.visit(node.dostatement)
+         self.do(node.condition)
+         self.do(node.statement)
       else:
-         raise Exception('UNK visit "{0}"'.format(node))
-      self.f2(node)
+           raise Exception('Could not visit "{0}"'.format(node))
 
+      # run post function
+      if self.f_post:
+            self.f_post(node)
+
--- a/python/ppci/errors.py	Fri Jul 05 11:18:58 2013 +0200
+++ b/python/ppci/errors.py	Fri Jul 05 13:00:03 2013 +0200
@@ -11,9 +11,14 @@
         self.loc = loc
         if loc:
             assert type(loc) is SourceLocation, '{0} must be SourceLocation'.format(type(loc))
+            self.row = loc.row
+            self.col = loc.col
+        else:
+            self.row = self.col = None
+
     def __repr__(self):
-        if self.loc:
-            return 'Compilererror: "{0}" at row {1}'.format(self.msg, self.loc.row)
+        if self.row:
+            return 'Compilererror: "{0}" at row {1}'.format(self.msg, self.row)
         else:
             return 'Compilererror: "{0}"'.format(self.msg)
 
@@ -24,7 +29,7 @@
         print('Error: {0}'.format(e.msg))
      else:
         lines = source.split('\n')
-        ro, co = e.loc.row, e.loc.col
+        ro, co = e.row, e.col
         prerow = ro - 2
         if prerow < 1:
            prerow = 1
@@ -45,12 +50,16 @@
 class DiagnosticsManager:
    def __init__(self):
       self.diags = []
+
    def addDiag(self, d):
       self.diags.append(d)
+
    def error(self, msg, loc):
       self.addDiag(CompilerError(msg, loc))
+
    def clear(self):
       del self.diags[:]
+
    def printErrors(self, src):
       if len(self.diags) > 0:
          print('{0} Errors'.format(len(self.diags)))
--- a/python/testc3.py	Fri Jul 05 11:18:58 2013 +0200
+++ b/python/testc3.py	Fri Jul 05 13:00:03 2013 +0200
@@ -4,7 +4,10 @@
 
 testsrc = """package test;
 
-var u32 c, d;
+/*
+ demo of the source that is correct :)
+*/ 
+var int c, d;
 var double e;
 var int f;
 
@@ -12,9 +15,9 @@
 
 function void test1() 
 {
-    var u32 bdd;
+    var int bdd;
     var int a = 10;
-    bd = 20;
+    bdd = 20;
     var int buf;
     var int i;
     i = 2;
@@ -31,21 +34,23 @@
     t2(2, 3);
 }
 
-function int t2(u32 a, u32 b)
+function int t2(int a, int b)
 {
    if (a > 0)
    {
-      a = 2 + t2(a - 1, 1.0);
+      a = 2 + t2(a - 1, 10);
    }
 
    return a + b;
 }
 
+var int a, b;
+
 function int t3(int aap, int blah)
 {
-   if (a > blah and blah < 45 + 33 or 33 > aap or 6 and true)
+   if (a > blah and blah < 45 + 33 or 33 > aap or 6 > 2 and true)
    {
-      a = 2 + t2(a - 1);
+      a = 2 + t2(a - 1, 0);
    }
 
    return a + b;
@@ -82,12 +87,30 @@
         self.assertSequenceEqual([tok.typ for tok in c3.lexer.tokenize(snippet)], toks)
 
 class testBuilder(unittest.TestCase):
-   def setUp(self):
-      self.diag = ppci.DiagnosticsManager()
-      self.builder = c3.Builder(self.diag)
-   def testSrc(self):
-      self.builder.build(testsrc)
-   def testFunctArgs(self):
+    def setUp(self):
+        self.diag = ppci.DiagnosticsManager()
+        self.builder = c3.Builder(self.diag)
+        self.diag.clear()
+
+    def testSrc(self):
+        self.expectOK(testsrc)
+
+    def expectErrors(self, snippet, rows):
+        """ Helper to test for expected errors on rows """
+        ircode = self.builder.build(snippet)
+        actualErrors = [err.row for err in self.diag.diags]
+        if rows != actualErrors:
+            self.diag.printErrors(snippet)
+        self.assertSequenceEqual(rows, actualErrors)
+        self.assertFalse(ircode)
+
+    def expectOK(self, snippet):
+        ircode = self.builder.build(snippet)
+        if not ircode:
+            self.diag.printErrors(snippet)
+        self.assertTrue(ircode)
+
+    def testFunctArgs(self):
       snippet = """
          package testargs;
          function void t2(int a, double b)
@@ -97,13 +120,9 @@
             t2(1, 1.2);
          }
       """
-      self.diag.clear()
-      ir = self.builder.build(snippet)
-      assert len(self.diag.diags) == 2
-      self.assertEqual(5, self.diag.diags[0].loc.row)
-      self.assertEqual(6, self.diag.diags[1].loc.row)
+      self.expectErrors(snippet, [5, 6])
 
-   def testExpressions(self):
+    def testExpressions(self):
       snippet = """
          package test;
          function void t(int a, double b)
@@ -116,41 +135,28 @@
             c = b > 1;
          }
       """
-      self.diag.clear()
-      ircode = self.builder.build(snippet)
-      self.assertEqual(len(self.diag.diags), 3)
-      self.assertEqual(self.diag.diags[0].loc.row, 8)
-      self.assertEqual(self.diag.diags[1].loc.row, 9)
-      self.assertEqual(self.diag.diags[2].loc.row, 10)
-      self.assertFalse(ircode)
+      self.expectErrors(snippet, [8, 9, 10])
 
-   def testEmpty(self):
+    def testEmpty(self):
       snippet = """
       package A
       """
-      ircode = self.builder.build(snippet)
-      self.assertFalse(ircode)
+      self.expectErrors(snippet, [3])
 
-   def testEmpty2(self):
+    def testEmpty2(self):
       snippet = ""
-      self.diag.clear()
-      ircode = self.builder.build(snippet)
-      self.assertFalse(ircode)
+      self.expectErrors(snippet, [1])
 
-   def testRedefine(self):
+    def testRedefine(self):
       snippet = """
       package test;
       var int a;
       var int b;
       var int a;
       """
-      self.diag.clear()
-      ircode = self.builder.build(snippet)
-      self.assertFalse(ircode)
-      self.assertEqual(len(self.diag.diags), 1)
-      self.assertEqual(self.diag.diags[0].loc.row, 5)
+      self.expectErrors(snippet, [5])
 
-   def testWhile(self):
+    def testWhile(self):
       snippet = """
       package tstwhile;
       var int a;
@@ -172,12 +178,10 @@
          }
       }
       """
-      ircode = self.builder.build(snippet)
-      if not ircode:
-        self.diag.printErrors(snippet)
-      self.assertTrue(ircode)
-   def testIf(self):
-      snippet = """
+      self.expectOK(snippet)
+
+    def testIf(self):
+        snippet = """
       package tstIFF;
       function void t(int b)
       {
@@ -197,14 +201,11 @@
 
          return b;
       }
-      """
-      ircode = self.builder.build(snippet)
-      if not ircode:
-        self.diag.printErrors(snippet)
-      self.assertTrue(ircode)
+        """
+        self.expectOK(snippet)
    
-   @unittest.skip 
-   def testPointerType(self):
+    @unittest.skip 
+    def testPointerType(self):
         snippet = """
          package testpointer;
          var int* pa;
@@ -213,14 +214,10 @@
             *pa = 22;
          }
         """
-        self.diag.clear()
-        ircode = self.builder.build(snippet)
-        if not ircode:
-            self.diag.printErrors(snippet)
-        self.assertTrue(ircode)
+        self.expectOK(snippet)
 
-   @unittest.skip 
-   def testComplexType(self):
+    @unittest.skip 
+    def testComplexType(self):
         snippet = """
          package testpointer;
          type int my_int;
@@ -248,15 +245,11 @@
             mxp->P1.x = a * x->P1.y;
          }
         """
-        self.diag.clear()
-        ircode = self.builder.build(snippet)
-        if not ircode:
-            self.diag.printErrors(snippet)
-        self.assertTrue(ircode)
+        self.expectOK(snippet)
 
-   def test2(self):
+    def test2(self):
         # testsrc2 is valid code:
-        testsrc2 = """
+        snippet = """
         package test2;
 
         function void tst()
@@ -278,9 +271,7 @@
         }
 
         """
-        self.diag.clear()
-        ir = self.builder.build(testsrc2)
-        self.assertTrue(ir)
+        self.expectOK(snippet)
 
 if __name__ == '__main__':
    unittest.main()