view python/c3/typecheck.py @ 229:51d5ed1bd503

Added testrunner
author Windel Bouwman
date Sat, 13 Jul 2013 11:13:01 +0200
parents 7f18ed9b6b7e
children 88a1e0baef65
line wrap: on
line source

from .astnodes import *
from .scope import *
from .visitor import Visitor

def equalTypes(a, b):
    """ 
        Compare types a and b for equality.
        Not equal until proven otherwise.

        Named types (DefinedType) are replaced by the type they stand for.
        The remaining types are compared structurally:
        - base types by name,
        - pointer types by their pointee type,
        - structures member-wise by member type (member names are not
          compared).
        Raises Exception when both types are of the same kind and that kind
        has no comparison implemented.
    """
    # The same type object is trivially equal to itself. Checking this first
    # also prevents endless recursion when a self-referential type (e.g. a
    # named structure holding a pointer to itself) is compared with itself.
    if a is b:
        return True
    # Recurse into named types:
    if type(a) is DefinedType:
        return equalTypes(a.typ, b)
    if type(b) is DefinedType:
        return equalTypes(a, b.typ)
    # Compare for structural equivalence:
    if type(a) is type(b):
        if type(a) is BaseType:
            return a.name == b.name
        elif type(a) is PointerType:
            return equalTypes(a.ptype, b.ptype)
        elif type(a) is StructureType:
            if len(a.mems) != len(b.mems):
                return False
            for amem, bmem in zip(a.mems, b.mems):
                if not equalTypes(amem.typ, bmem.typ):
                    return False
            return True
        else:
            raise Exception('Type compare not implemented')
    return False

def canCast(fromT, toT):
    """ Return True when a value of type fromT may be explicitly cast to toT.

        Allowed casts are pointer -> pointer and int -> pointer.
        Named types (DefinedType) are looked through first, so an alias of a
        pointer type or of int behaves like the type it names.
    """
    # Unwrap (possibly chained) type aliases on both sides.
    while type(fromT) is DefinedType:
        fromT = fromT.typ
    while type(toT) is DefinedType:
        toT = toT.typ
    # Only pointer targets are supported.
    if not isinstance(toT, PointerType):
        return False
    # Use equalTypes rather than object identity so any int type matches.
    return isinstance(fromT, PointerType) or equalTypes(fromT, intType)

class TypeChecker:
    """ Checks the types of a package's AST nodes and annotates them.

        Expression nodes get a `typ` and an `lvalue` attribute assigned.
        Errors are reported via the diagnostics object given to the
        constructor; any error marks the current check as failed.
    """
    def __init__(self, diag):
        self.diag = diag
        # Defined here as well so `ok` exists before the first checkPackage
        # call; checkPackage resets it for every run.
        self.ok = True

    def error(self, msg, loc):
        """ 
            Wrapper that registers the message and marks the result invalid 
        """
        self.diag.error(msg, loc)
        self.ok = False

    def checkPackage(self, pkg):
        """ Type check every node of pkg; return True if no error occurred. """
        self.ok = True
        visitor = Visitor()
        # NOTE: check2 reads the typ of child nodes (e.g. Binop reads
        # sym.a.typ), so f_post is presumably called after a node's children
        # have been visited — confirm against Visitor.
        visitor.visit(pkg, f_post=self.check2)
        return self.ok

    def check2(self, sym):
        """ Check a single node and annotate it with `typ` / `lvalue`.

            Nodes that fail a check get a fallback type so checking can
            continue and report further errors.
            Raises Exception for unknown node kinds, literals or operators.
        """
        if type(sym) in [IfStatement, WhileStatement]:
            if not equalTypes(sym.condition.typ, boolType):
                self.error('Condition must be of type {0}'.format(boolType), sym.condition.loc)
        elif type(sym) is Assignment:
            if not equalTypes(sym.lval.typ, sym.rval.typ):
                self.error('Cannot assign {0} to {1}'.format(sym.rval.typ, sym.lval.typ), sym.loc)
            if not sym.lval.lvalue:
                self.error('No valid lvalue {}'.format(sym.lval), sym.lval.loc)
            # TODO: optionally require the right hand side to be an rvalue.
        elif type(sym) is ReturnStatement:
            # TODO: the returned value is not checked against the function's
            # return type yet.
            pass
        elif type(sym) is FunctionCall:
            # The result of a call is a value, it cannot be assigned to.
            sym.lvalue = False
            # Check arguments:
            ngiv = len(sym.args)
            ptypes = sym.proc.typ.parametertypes
            nreq = len(ptypes)
            if ngiv != nreq:
                self.error('Function {2}: {0} arguments required, {1} given'.format(nreq, ngiv, sym.proc.name), sym.loc)
            else:
                for a, at in zip(sym.args, ptypes):
                    if not equalTypes(a.typ, at):
                        self.error('Got {0}, expected {1}'.format(a.typ, at), a.loc)
            # determine return type:
            sym.typ = sym.proc.typ.returntype
        elif type(sym) is VariableUse:
            sym.lvalue = True
            if type(sym.target) is Variable:
                sym.typ = sym.target.typ
            else:
                print('warning {} has no target, defaulting to int'.format(sym))
                sym.typ = intType
        elif type(sym) is Literal:
            sym.lvalue = False
            # type() (not isinstance) keeps bool literals from matching int.
            if type(sym.val) is int:
                sym.typ = intType
            elif type(sym.val) is float:
                sym.typ = doubleType
            elif type(sym.val) is bool:
                sym.typ = boolType
            else:
                raise Exception('Unknown literal type {0}'.format(sym.val))
        elif type(sym) is Unop:
            if sym.op == '&':
                sym.typ = PointerType(sym.a.typ)
                sym.lvalue = False
            else:
                raise Exception('Unknown unop {0}'.format(sym.op))
        elif type(sym) is Deref:
            # pointer deref
            sym.lvalue = True
            # check if the to be dereferenced variable is a pointer type,
            # looking through (possibly chained) type aliases:
            ptype = sym.ptr.typ
            while type(ptype) is DefinedType:
                ptype = ptype.typ
            if type(ptype) is PointerType:
                sym.typ = ptype.ptype
            else:
                self.error('Cannot dereference non-pointer type {}'.format(ptype), sym.loc)
                sym.typ = intType
        elif type(sym) is FieldRef:
            sym.lvalue = True
            basetype = sym.base.typ
            while type(basetype) is DefinedType:
                basetype = basetype.typ
            if type(basetype) is StructureType:
                if basetype.hasField(sym.field):
                    sym.typ = basetype.fieldType(sym.field)
                else:
                    self.error('{} does not contain field {}'.format(basetype, sym.field), sym.loc)
                    sym.typ = intType
            else:
                self.error('Cannot select field {} of non-structure type {}'.format(sym.field, basetype), sym.loc)
                sym.typ = intType
        elif type(sym) is Binop:
            sym.lvalue = False
            if sym.op in ['+', '-', '*', '/']:
                if equalTypes(sym.a.typ, sym.b.typ):
                    if equalTypes(sym.a.typ, intType):
                        sym.typ = sym.a.typ
                    else:
                        self.error('Can only use integers with {0}'.format(sym.op), sym.loc)
                        sym.typ = intType
                else:
                    # Fall back to int so checking of the parent can continue.
                    sym.typ = intType
                    self.error('Types unequal {} != {}'.format(sym.a.typ, sym.b.typ), sym.loc)
            elif sym.op in ['>', '<', '==', '<=', '>=']:
                sym.typ = boolType
                if not equalTypes(sym.a.typ, sym.b.typ):
                    self.error('Types unequal {} != {}'.format(sym.a.typ, sym.b.typ), sym.loc)
            elif sym.op in ['or', 'and']:
                sym.typ = boolType
                if not equalTypes(sym.a.typ, boolType):
                    self.error('Must be {0}'.format(boolType), sym.a.loc)
                if not equalTypes(sym.b.typ, boolType):
                    self.error('Must be {0}'.format(boolType), sym.b.loc)
            elif sym.op in ['|', '&']:
                sym.typ = intType
                if equalTypes(sym.a.typ, sym.b.typ):
                    if not equalTypes(sym.a.typ, intType):
                        self.error('Can only use integers with {0}'.format(sym.op), sym.loc)
                else:
                    self.error('Types unequal {} != {}'.format(sym.a.typ, sym.b.typ), sym.loc)
            else:
                raise Exception('Unknown binop {0}'.format(sym.op))
        elif type(sym) is Variable:
            # TODO: check the type of the initial value.
            pass
        elif type(sym) is TypeCast:
            sym.lvalue = False
            # Use the target type even when the cast is invalid, so parent
            # expressions are checked against the intended type.
            sym.typ = sym.to_type
            if not canCast(sym.a.typ, sym.to_type):
                self.error('Cannot cast {} to {}'.format(sym.a.typ, sym.to_type), sym.loc)
        elif type(sym) is Constant:
            if not equalTypes(sym.typ, sym.value.typ):
                self.error('Cannot assign {0} to {1}'.format(sym.value.typ, sym.typ), sym.loc)
        elif type(sym) in [CompoundStatement, Package, Function, FunctionType, ExpressionStatement, DefinedType]:
            pass
        else:
            raise Exception('Unknown type check {0}'.format(sym))