view python/c3/typecheck.py @ 220:3f6c30a5d234

Major change in expression parsing to enable pointers and structs
author Windel Bouwman
date Sat, 06 Jul 2013 21:32:20 +0200
parents c1ccb1cb4cef
children 848c4b15fd0b
line wrap: on
line source

from .astnodes import *
from .scope import *
from .visitor import Visitor

def equalTypes(a, b):
    """ Decide whether the types a and b are the same type.

    Types are unequal unless proven otherwise: both must be instances of the
    exact same node class, and that class must be one handled here. Base
    types compare by name; pointer types compare by the types they point to
    (recursively). Any other kind of type compares unequal.
    """
    if type(a) is not type(b):
        return False
    kind = type(a)
    if kind is BaseType:
        return a.name == b.name
    if kind is PointerType:
        return equalTypes(a.ptype, b.ptype)
    return False

class TypeChecker:
    """ Type checker for a c3 package.

    check2 is registered as the visitor's post-visit callback (f_post); the
    expression checks below rely on sub-expressions already having their
    ``typ`` attribute set when the parent node is checked. Expressions get
    ``typ`` assigned here; statements and declarations are checked against
    the types of their operands. Problems are reported through ``diag`` and
    mark the run as failed via ``self.ok``.
    """

    def __init__(self, diag):
        # diag: diagnostics sink; only its error(msg, loc) method is used here.
        self.diag = diag
        # Initialised here too so error()/check2 are usable without
        # checkPackage having been called first.
        self.ok = True

    def error(self, msg, loc):
        """ Wrapper that registers the message and marks the result invalid """
        self.diag.error(msg, loc)
        self.ok = False

    def checkPackage(self, pkg):
        """ Type check every node of pkg.

        Returns True when no type errors were reported, False otherwise.
        """
        self.ok = True
        visitor = Visitor()
        visitor.visit(pkg, f_post=self.check2)
        return self.ok

    def check2(self, sym):
        """ Check a single AST node and, for expressions, compute its type.

        Raises Exception for node classes this checker does not know about.
        """
        # Exact type() matching (not isinstance) so that node subclasses are
        # never dispatched to a parent class's check by accident.
        if type(sym) in [IfStatement, WhileStatement]:
            if not equalTypes(sym.condition.typ, boolType):
                self.error('Condition must be of type {0}'.format(boolType), sym.condition.loc)
        elif type(sym) is Assignment:
            self._checkAssignment(sym)
        elif type(sym) is ReturnStatement:
            # TODO: compare against the enclosing function's return type.
            pass
        elif type(sym) is FunctionCall:
            self._checkFunctionCall(sym)
        elif type(sym) is VariableUse:
            if sym.target:
                sym.typ = sym.target.typ
            else:
                # Unresolved name: default to int so parent expressions can
                # still be checked (presumably the lookup failure is reported
                # elsewhere — confirm).
                sym.typ = intType
        elif type(sym) is Literal:
            self._checkLiteral(sym)
        elif type(sym) is Unop:
            self._checkUnop(sym)
        elif type(sym) is Binop:
            self._checkBinop(sym)
        elif type(sym) is Variable:
            # TODO: check the type of the initial value.
            pass
        elif type(sym) is Constant:
            if not equalTypes(sym.typ, sym.value.typ):
                self.error('Cannot assign {0} to {1}'.format(sym.value.typ, sym.typ), sym.loc)
        elif type(sym) in [EmptyStatement, CompoundStatement, Package, Function, FunctionType]:
            pass
        else:
            raise Exception('Unknown type check {0}'.format(sym))

    def _checkAssignment(self, sym):
        """ Check that the assigned value's type matches the target's type. """
        ltyp = sym.lval.typ
        rtyp = sym.rval.typ
        if type(ltyp) is PointerType and equalTypes(rtyp, intType):
            # Special case: an integer may be assigned to a pointer for now.
            # TODO: add cast instruction?
            return
        if not equalTypes(ltyp, rtyp):
            self.error('Cannot assign {0} to {1}'.format(rtyp, ltyp), sym.loc)

    def _checkFunctionCall(self, sym):
        """ Check call arguments against the callee's parameter types and set
        the call's type to the callee's return type. """
        if not sym.proc:
            # Unresolved callee: default to int so checking can continue.
            sym.typ = intType
            return
        ptypes = sym.proc.typ.parametertypes
        ngiv = len(sym.args)
        nreq = len(ptypes)
        if ngiv != nreq:
            self.error('Function {2}: {0} arguments required, {1} given'.format(nreq, ngiv, sym.proc.name), sym.loc)
        else:
            for arg, argtyp in zip(sym.args, ptypes):
                if not equalTypes(arg.typ, argtyp):
                    self.error('Got {0}, expected {1}'.format(arg.typ, argtyp), arg.loc)
        sym.typ = sym.proc.typ.returntype

    def _checkLiteral(self, sym):
        """ Derive a literal's type from the Python type of its value. """
        # Exact type() checks: True/False are bool, not int, here.
        if type(sym.val) is int:
            sym.typ = intType
        elif type(sym.val) is float:
            sym.typ = doubleType
        elif type(sym.val) is bool:
            sym.typ = boolType
        else:
            self.error('Unknown literal type', sym.loc)

    def _checkUnop(self, sym):
        """ Type a unary expression: '&' (address-of) or '*' (dereference). """
        if sym.op == '&':
            sym.typ = PointerType(sym.a.typ)
        elif sym.op == '*':
            if type(sym.a.typ) is PointerType:
                sym.typ = sym.a.typ.ptype
            else:
                self.error('Cannot dereference non-pointer type {}'.format(sym.a.typ), sym.loc)
                # Fallback so parent expressions always find a typ.
                sym.typ = intType
        else:
            self.error('Unknown unary operator {0}'.format(sym.op), sym.loc)
            sym.typ = voidType

    def _checkBinop(self, sym):
        """ Type a binary expression (arithmetic, comparison or logical). """
        if sym.op in ['+', '-', '*', '/']:
            if not equalTypes(sym.a.typ, sym.b.typ):
                # TODO: assume void here / throw exception?
                sym.typ = intType
                self.error('Types unequal', sym.loc)
            elif equalTypes(sym.a.typ, intType):
                sym.typ = sym.a.typ
            else:
                self.error('Can only add integers', sym.loc)
                sym.typ = intType
        elif sym.op in ['>', '<', '==', '!=', '<=', '>=']:
            sym.typ = boolType
            if not equalTypes(sym.a.typ, sym.b.typ):
                self.error('Types unequal', sym.loc)
        elif sym.op in ['or', 'and']:
            sym.typ = boolType
            for operand in (sym.a, sym.b):
                if not equalTypes(operand.typ, boolType):
                    self.error('Must be {0}'.format(boolType), operand.loc)
        else:
            self.error('Unknown binary operator {0}'.format(sym.op), sym.loc)
            sym.typ = voidType