changeset 150:4ae0e02599de

Added type check start and analyze phase
author Windel Bouwman
date Fri, 01 Mar 2013 16:53:22 +0100
parents 74241ca312cc
children afc8c0207984
files python/c3/__init__.py python/c3/analyse.py python/c3/astnodes.py python/c3/parser.py python/c3/scope.py python/c3/semantics.py python/c3/typecheck.py python/testc3.py
diffstat 8 files changed, 189 insertions(+), 101 deletions(-) [+]
line wrap: on
line diff
--- a/python/c3/__init__.py	Fri Mar 01 11:43:52 2013 +0100
+++ b/python/c3/__init__.py	Fri Mar 01 16:53:22 2013 +0100
@@ -1,2 +1,9 @@
+
+# Convenience:
+
+from .parser import Parser
+from .semantics import Semantics
+from .typecheck import TypeChecker
+from .analyse import Analyzer
 
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/python/c3/analyse.py	Fri Mar 01 16:53:22 2013 +0100
@@ -0,0 +1,5 @@
+
+class Analyzer:
+   def __init__(self, diag):
+      self.diag = diag
+
--- a/python/c3/astnodes.py	Fri Mar 01 11:43:52 2013 +0100
+++ b/python/c3/astnodes.py	Fri Mar 01 16:53:22 2013 +0100
@@ -17,22 +17,11 @@
                   children.append(mi)
       return children
 
-
-class Id(Node):
-   def __init__(self, tok, pub):
-      self.name = tok.val
-      self.is_public = pub
+class Designator(Node):
+   def __init__(self, tname):
+      self.tname = tname
    def __repr__(self):
-      return 'ID {0}'.format(self.name)
-
-# Selectors:
-class Designator(Node):
-   def __init__(self, obj, selectors, typ):
-      self.obj = obj
-      self.selectors = selectors
-      self.typ = typ
-   def __repr__(self):
-      return 'DESIGNATOR {0}, selectors {1}, type {2}'.format(self.obj, self.selectors, self.typ)
+      return 'DESIGNATOR {0}'.format(self.tname)
 
 """
 Type classes
@@ -109,13 +98,15 @@
    def __repr__(self):
       return 'VAR {0} : {1}'.format(self.name, self.typ)
 
-class Parameter(Node):
-   """ A parameter has a passing method, name and typ """
-   def __init__(self, name, typ):
+# Procedure types
+class Function(Symbol):
+   """ Actual implementation of a function """
+   def __init__(self, name, typ=None, block=None):
       self.name = name
+      self.body = block
       self.typ = typ
    def __repr__(self):
-      return 'PARAM {0} {1}'.format(self.name, self.typ)
+      return 'PROCEDURE {0} {1}'.format(self.name, self.typ)
 
 # Operations:
 class Unop(Node):
@@ -133,6 +124,12 @@
    def __repr__(self):
       return 'BINOP {0}'.format(self.op)
 
+class VariableUse(Node):
+   def __init__(self, target):
+      self.target = target
+   def __repr__(self):
+      return 'VAR USE {0}'.format(self.target)
+
 # Modules
 class Package(Node):
    def __init__(self, name):
@@ -140,16 +137,6 @@
    def __repr__(self):
       return 'PACKAGE {0}'.format(self.name)
 
-# Procedure types
-class Procedure(Symbol):
-   """ Actual implementation of a function """
-   def __init__(self, name, typ=None, block=None):
-      self.name = name
-      self.body = block
-      self.typ = typ
-   def __repr__(self):
-      return 'PROCEDURE {0} {1}'.format(self.name, self.typ)
-
 # Statements
 class CompoundStatement(Node):
    def __init__(self, statements):
--- a/python/c3/parser.py	Fri Mar 01 11:43:52 2013 +0100
+++ b/python/c3/parser.py	Fri Mar 01 16:53:22 2013 +0100
@@ -19,9 +19,6 @@
          self.diag.diag(e)
    def Error(self, msg):
       raise CompilerException(msg, self.token.loc)
-   def skipToSemiCol(self):
-      while not (self.Peak == ';' or self.Peak == 'END'):
-         self.NextToken()
    # Lexer helpers:
    def Consume(self, typ):
       if self.Peak == typ:
@@ -43,8 +40,7 @@
       return False
    def NextToken(self):
       t = self.token
-      if t.typ != 'END':
-         self.token = self.tokens.__next__()
+      if t.typ != 'END': self.token = self.tokens.__next__()
       return t
    def initLex(self, source):
       self.tokens = lexer.tokenize(source) # Lexical stage
@@ -72,7 +68,7 @@
    def parseDesignator(self):
       """ A designator designates an object """
       name = self.Consume('ID')
-      return name.val
+      return self.sema.actOnDesignator(name.val, name.loc)
 
    # Type system
    def parseType(self):
@@ -102,19 +98,20 @@
       self.Consume('(')
       parameters = []
       if not self.hasConsumed(')'):
-         typ = self.parseType()
-         name = self.Consume('ID')
-         parameters.append(astnodes.Parameter(name, typ))
-         while self.hasConsumed(','):
+         def parseParameter():
             typ = self.parseType()
             name = self.Consume('ID')
-            parameters.append(astnodes.Parameter(name, typ))
+            parameters.append(self.sema.actOnParameter(name.val, name.loc, typ))
+         parseParameter()
+         while self.hasConsumed(','):
+            parseParameter()
          self.Consume(')')
       body = self.parseCompoundStatement()
       self.sema.actOnFuncDef2(parameters, returntype, body)
 
    # Statements:
    def parseAssignment(self, lval):
+      lval = astnodes.VariableUse(lval)
       self.Consume('=')
       rval = self.parseExpression()
       return astnodes.Assignment(lval, rval)
@@ -194,7 +191,7 @@
          return self.sema.actOnNumber(val.val, val.loc)
       elif self.Peak == 'ID':
          d = self.parseDesignator()
-         return d
+         return self.sema.actOnVariableUse(d)
       self.Error('Expected NUM, ID or (expr), got {0}'.format(self.Peak))
 
    def parseBinopRhs(self, lhs, min_prec):
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/python/c3/scope.py	Fri Mar 01 16:53:22 2013 +0100
@@ -0,0 +1,33 @@
+from . import astnodes
+
+class Scope:
+   """ A scope contains all symbols in a scope """
+   def __init__(self, parent=None):
+      self.symbols = {}
+      self.parent = parent
+   def __iter__(self):
+      return iter(self.symbols.values())
+   def getSymbol(self, name):
+      if name in self.symbols:
+         return self.symbols[name]
+      # Look for symbol:
+      if self.parent:
+         return self.parent.getSymbol(name)
+      raise CompilerException("Symbol {0} not found".format(name), name.loc)
+   def hasSymbol(self, name):
+      if name in self.symbols:
+         return True
+      if self.parent:
+         return self.parent.hasSymbol(name)
+      return False
+   def addSymbol(self, sym):
+      self.symbols[sym.name] = sym
+
+def createBuiltins(scope):
+   for tn in ['int', 'u32', 'u16', 'double']:
+      scope.addSymbol(astnodes.BaseType(tn))
+
+topScope = Scope()
+
+createBuiltins(topScope)
+
--- a/python/c3/semantics.py	Fri Mar 01 11:43:52 2013 +0100
+++ b/python/c3/semantics.py	Fri Mar 01 16:53:22 2013 +0100
@@ -1,46 +1,49 @@
 from . import astnodes
-
-class Scope:
-   """ A scope contains all symbols in a scope """
-   def __init__(self, parent=None):
-      self.symbols = {}
-      self.parent = parent
-   def __iter__(self):
-      return iter(self.symbols.values())
-   def getType(self, name):
-      t = self.getSymbol(name)
-      print(t)
-      assert isinstance(t, astnodes.Type)
-      return t
-   def getSymbol(self, name):
-      if name in self.symbols:
-         return self.symbols[name]
-      # Look for symbol:
-      if self.parent:
-         return self.parent.getSymbol(name)
-      raise CompilerException("Symbol {0} not found".format(name), name.loc)
-   def hasSymbol(self, name):
-      if name in self.symbols:
-         return True
-      if self.parent:
-         return self.parent.hasSymbol(name)
-      return False
-      
-   def addSymbol(self, sym):
-      self.symbols[sym.name] = sym
-
-def createBuiltins(scope):
-   scope.addSymbol(astnodes.BaseType('int'))
+from .scope import Scope, topScope
+from ppci.errors import CompilerException
 
 class Semantics:
    """ This class constructs the AST from parser input """
    def __init__(self, diag):
       self.diag = diag
+   def addSymbol(self, s):
+      if self.curScope.hasSymbol(s.name):
+         msg = 'Redefinition of {0}'.format(s.name)
+         self.diag.diag(CompilerException(msg, s.loc))
+      else:
+         self.curScope.addSymbol(s)
    def handlePackage(self, name, loc):
       self.mod = astnodes.Package(name)
       self.mod.loc = loc
-      self.mod.scope = self.curScope = Scope()
-      createBuiltins(self.curScope)
+      self.mod.scope = self.curScope = Scope(topScope)
+   def actOnVarDef(self, name, loc, t, ival):
+      s = astnodes.Variable(name, t)
+      s.loc = loc
+      self.addSymbol(s)
+   def actOnFuncDef1(self, name, loc):
+      self.curFunc = astnodes.Function(name)
+      self.curFunc.loc = loc
+      self.addSymbol(self.curFunc)
+      self.curScope = self.curFunc.scope = Scope(self.curScope)
+   def actOnParameter(self, name, loc, t):
+      p = astnodes.Variable(name, t)
+      p.loc = loc
+      p.parameter = True
+      self.addSymbol(p)
+      return p
+   def actOnFuncDef2(self, parameters, returntype, body):
+      self.curFunc.body = body
+      self.curFunc.typ = astnodes.FunctionType(parameters, returntype)
+      self.curFunc = None
+      self.curScope = self.curScope.parent
+   def actOnType(self, tok):
+      # Try to lookup type, in case of failure return void
+      pass
+   def actOnDesignator(self, tname, loc):
+      d = astnodes.Designator(tname)
+      d.scope = self.curScope
+      d.loc = loc
+      return d
    def actOnBinop(self, lhs, op, rhs, loc):
       bo = astnodes.Binop(lhs, op, rhs)
       bo.loc = loc
@@ -49,21 +52,6 @@
       n = astnodes.Constant(num)
       n.loc = loc
       return n
-   def actOnVarDef(self, name, loc, t, ival):
-      s = astnodes.Variable(name, t)
-      s.loc = loc
-      self.curScope.addSymbol(s)
-   def actOnFuncDef1(self, name, loc):
-      self.curFunc = astnodes.Procedure(name)
-      self.curFunc.loc = loc
-      self.curScope.addSymbol(self.curFunc)
-      self.curScope = self.curFunc.scope = Scope(self.curScope)
-   def actOnFuncDef2(self, parameters, returntype, body):
-      self.curFunc.body = body
-      self.curFunc.typ = astnodes.FunctionType(parameters, returntype)
-      self.curFunc = None
-      self.curScope = self.curScope.parent
-   def actOnType(self, tok):
-      # Try to lookup type, in case of failure return void
-      pass
+   def actOnVariableUse(self, d):
+      return astnodes.VariableUse(d)
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/python/c3/typecheck.py	Fri Mar 01 16:53:22 2013 +0100
@@ -0,0 +1,62 @@
+from .astnodes import BaseType, Variable, Designator, Function
+from .astnodes import CompoundStatement, Assignment, VariableUse
+from .astnodes import Binop, Unop, Constant
+from .astnodes import IfStatement, WhileStatement, ReturnStatement
+from .astnodes import FunctionType, BaseType
+from . import astnodes
+from ppci.errors import CompilerException
+from .scope import topScope
+
+class TypeChecker:
+   def __init__(self, diag):
+      self.diag = diag
+   def err(self, msg, loc):
+      self.diag.diag(CompilerException(msg, loc))
+   def checkPackage(self, pkg):
+      for s in pkg.scope:
+         self.check(s)
+   def resolveDesignator(self, d):
+      if d.scope.hasSymbol(d.tname):
+         return d.scope.getSymbol(d.tname)
+      else:
+         msg = 'Cannot resolve name {0}'.format(d.tname)
+         self.err(msg, d.loc)
+   def check(self, sym):
+      if type(sym) is Variable:
+         if type(sym.typ) is Designator:
+            sym.typ = self.resolveDesignator(sym.typ)
+      elif type(sym) is Function:
+         for s in sym.scope:
+            self.check(s)
+         self.check(sym.typ)
+         self.check(sym.body)
+      elif type(sym) is CompoundStatement:
+         for s in sym.statements:
+            self.check(s)
+      elif type(sym) is IfStatement:
+         self.check(sym.condition)
+         print(sym.condition)
+         self.check(sym.truestatement)
+         self.check(sym.falsestatement)
+      elif type(sym) is VariableUse:
+         if type(sym.target) is Designator:
+            sym.target = self.resolveDesignator(sym.target)
+      elif type(sym) is Assignment:
+         self.check(sym.lval)
+         self.check(sym.rval)
+      elif type(sym) is ReturnStatement:
+         self.check(sym.expr)
+      elif type(sym) is Constant:
+         if type(sym.value) is int:
+            sym.typ = topScope.getSymbol('int')
+      elif type(sym) is FunctionType:
+         if type(sym.returntype) is Designator:
+            sym.returntype = self.resolveDesignator(sym.returntype)
+         self.check(sym.returntype)
+      elif type(sym) is Binop:
+         self.check(sym.a)
+         self.check(sym.b)
+         if type(sym.a) is Constant and type(sym.b) is Constant:
+            # Possibly fold expression
+            pass
+
--- a/python/testc3.py	Fri Mar 01 11:43:52 2013 +0100
+++ b/python/testc3.py	Fri Mar 01 16:53:22 2013 +0100
@@ -1,11 +1,10 @@
-import c3.parser, c3.semantics
+import c3
 from ppci.errors import printError, Diagnostics
 import time
 
 testsrc = """
 package test;
 
-var u32 a ;
 var u32 c, d;
 
 function void test1() 
@@ -13,12 +12,13 @@
     var u32 b;
     var int a = 10;
     b = 20;
-    var int buf;
+    var int88 buf;
     var int i;
     i = 2;
+    zero = i - 2;
     if (i > 1)
     {
-       buf = b + 22 * i - 13 + (55 * 2 *9-2) / 44
+       buf = b + 22 * i - 13 + (55 * 2 *9-2) / 44 - one
     }
 }
 
@@ -28,10 +28,12 @@
    a = 2
 }
 
+var u8 hahaa = 23 * 2;
+
 function int t2(u32 a, u32 b)
 {
    return a + b;
-   a = 2 + 33 - 1;
+   a = 2 + 33 * 1 - 3;
    a = a - 2
 }
 
@@ -42,7 +44,7 @@
    if isinstance(ast, c3.astnodes.Package):
       for s in ast.scope:
          printAst(s, indent + '  ')
-   if isinstance(ast, c3.astnodes.Procedure):
+   if isinstance(ast, c3.astnodes.Function):
       for s in ast.scope:
          printAst(s, indent + '  ')
    for c in ast.getChildren():
@@ -53,15 +55,22 @@
    print(testsrc)
    print('[1] parsing')
    diag = Diagnostics()
-   sema = c3.semantics.Semantics(diag)
-   p = c3.parser.Parser(sema, diag)
+   sema = c3.Semantics(diag)
+   p = c3.Parser(sema, diag)
+   tc = c3.TypeChecker(diag)
+   al = c3.Analyzer(diag)
    t1 = time.time()
    p.parseSource(testsrc)
    t2 = time.time() 
    print('parsetime: {0} [s]'.format(t2 - t1))
+   t2 = time.time() 
+   tc.checkPackage(sema.mod)
+   t3 = time.time() 
+   print('checktime: {0} [s]'.format(t3 - t2))
+   print('{0} errors'.format(len(diag.diags)))
 
    for d in diag.diags:
-      print('ERROR:', d)
+      print('ERROR:')
       printError(testsrc, d)
    print('[2] ast:')
    printAst(sema.mod)