view python/c3/parser.py @ 276:56d37ed4b4d2

phaa
author Windel Bouwman
date Mon, 16 Sep 2013 21:51:17 +0200
parents e64bae57cda8
children 02385f62f250
line wrap: on
line source

import logging
from . import astnodes, lexer
from ppci import CompilerError

class Parser:
    """ Parses sourcecode into an abstract syntax tree (AST) """
    def __init__(self, diag):
        self.logger = logging.getLogger('c3')
        self.diag = diag

    def parseSource(self, source):
        """ Parse `source` and return the resulting package node.

        Returns self.mod (set by parsePackage) on success. When a
        CompilerError is raised, it is handed to self.diag.addDiag and
        None is returned.
        """
        self.logger.info('Parsing source')
        try:
            # initLex already fetches the first token, so it lives inside
            # the try: a CompilerError raised while producing that token is
            # reported through diag just like errors on any later token.
            self.initLex(source)
            self.parsePackage()
        except CompilerError as e:
            self.diag.addDiag(e)
            return None
        return self.mod

    def Error(self, msg):
        """ Abort parsing by raising CompilerError(msg) at the current token's location. """
        where = self.token.loc
        raise CompilerError(msg, where)

    # Lexer helpers:
    def Consume(self, typ):
        if self.Peak == typ:
            return self.NextToken()
        else:
            self.Error('Excected: "{0}", got "{1}"'.format(typ, self.Peak))

    @property
    def Peak(self):
        return self.token.typ

    @property
    def CurLoc(self):
        return self.token.loc

    def hasConsumed(self, typ):
        if self.Peak == typ:
            self.Consume(typ)
            return True
        return False

    def NextToken(self):
        t = self.token
        if t.typ != 'END':
            self.token = self.tokens.__next__()
        return t

    def initLex(self, source):
        """ Start tokenizing `source` and load the first token into self.token. """
        self.tokens = lexer.tokenize(source)  # Lexical stage
        # Use the next() builtin rather than calling the dunder directly.
        self.token = next(self.tokens)

    def addDeclaration(self, decl):
        self.currentPart.declarations.append(decl)
    
    def parseImport(self):
        self.Consume('import')
        name = self.Consume('ID').val
        self.mod.imports.append(name)
        self.Consume(';')

    def parsePackage(self):
        """ Parse 'package' ID ';' followed by top level items up to END.

        The created astnodes.Package is stored in self.mod and also becomes
        the current part that receives declarations.
        """
        self.Consume('package')
        pkg_tok = self.Consume('ID')
        self.Consume(';')
        package = astnodes.Package(pkg_tok.val, pkg_tok.loc)
        self.mod = package
        self.currentPart = package
        while not self.Peak == 'END':
            self.parseTopLevel()
        self.Consume('END')

    def parseTopLevel(self):
        if self.Peak == 'function':
            self.parseFunctionDef()
        elif self.Peak == 'var':
            self.parseVarDef()
        elif self.Peak == 'const':
            self.parseConstDef()
        elif self.Peak == 'type':
            self.parseTypeDef()
        elif self.Peak == 'import':
            self.parseImport()
        else:
            self.Error('Expected function, var, const or type')

    def parseDesignator(self):
        """ Parse a single ID token into an astnodes.Designator (a name that designates an object). """
        ident = self.Consume('ID')
        designator = astnodes.Designator(ident.val, ident.loc)
        return designator

    # Type system
    def parseTypeSpec(self):
        """ Parse a type specification and return its AST node.

        Either a struct ('struct' '{' {typespec ID {',' ID} ';'} '}') or a
        plain designator, followed by any number of '*' suffixes, each of
        which wraps the type in an astnodes.PointerType.
        """
        if self.hasConsumed('struct'):
            self.Consume('{')
            fields = []
            while not self.hasConsumed('}'):
                field_type = self.parseTypeSpec()
                # One type may be shared by several comma separated names.
                while True:
                    field_name = self.Consume('ID').val
                    fields.append(astnodes.StructField(field_name, field_type))
                    if not self.hasConsumed(','):
                        break
                self.Consume(';')
            result = astnodes.StructureType(fields)
        else:
            result = self.parseDesignator()
        # Pointer suffixes:
        while self.hasConsumed('*'):
            result = astnodes.PointerType(result)
        return result

    def parseTypeDef(self):
        """ Parse 'type' typespec ID ';' and declare the resulting astnodes.DefinedType. """
        self.Consume('type')
        aliased = self.parseTypeSpec()
        name_tok = self.Consume('ID')
        self.Consume(';')
        self.addDeclaration(
            astnodes.DefinedType(name_tok.val, aliased, name_tok.loc))

    # Variable declarations:
    def parseVarDef(self):
        """ Parse 'var' typespec ID ['=' expr] {',' ID ['=' expr]} ';'.

        Every name becomes an astnodes.Variable of the same parsed type and
        is added to the current part; an initializer, when present, is
        stored in the variable's ival attribute.
        """
        self.Consume('var')
        var_type = self.parseTypeSpec()
        while True:
            name_tok = self.Consume('ID')
            variable = astnodes.Variable(name_tok.val, var_type)
            variable.loc = name_tok.loc
            if self.hasConsumed('='):
                variable.ival = self.Expression()
            self.addDeclaration(variable)
            if not self.hasConsumed(','):
                break
        self.Consume(';')

    def parseConstDef(self):
        """ Parse 'const' typespec ID '=' expr {',' ID '=' expr} ';'.

        Every name becomes an astnodes.Constant of the same parsed type and
        is added to the declarations of the current part.
        """
        self.Consume('const')
        t = self.parseTypeSpec()
        while True:
            name = self.Consume('ID')
            self.Consume('=')
            val = self.Expression()
            c = astnodes.Constant(name.val, t, val)
            c.loc = name.loc
            # Register the constant in the current part, as parseVarDef does
            # for variables; without this the parsed constant is lost.
            self.addDeclaration(c)
            if not self.hasConsumed(','):
                break
        self.Consume(';')
      
    # Procedures
    def parseFunctionDef(self):
        """ Parse 'function' typespec ID '(' [params] ')' compound-statement.

        The astnodes.Function is declared in the enclosing part and then
        becomes the current part while its parameters and body are parsed,
        so parameters (and 'var' statements in the body) are declared on the
        function. The enclosing part is restored afterwards.
        """
        start_loc = self.Consume('function').loc
        result_type = self.parseTypeSpec()
        func_name = self.Consume('ID').val
        func = astnodes.Function(func_name, start_loc)
        self.addDeclaration(func)
        enclosing = self.currentPart
        self.currentPart = func
        self.Consume('(')
        params = []
        if not self.hasConsumed(')'):
            while True:
                param_type = self.parseTypeSpec()
                param_tok = self.Consume('ID')
                formal = astnodes.FormalParameter(param_tok.val, param_type)
                formal.loc = param_tok.loc
                self.addDeclaration(formal)
                params.append(formal)
                if not self.hasConsumed(','):
                    break
            self.Consume(')')
        func.typ = astnodes.FunctionType([p.typ for p in params], result_type)
        func.body = self.parseCompoundStatement()
        self.currentPart = enclosing

    # Statements:

    def parseIfStatement(self):
        """ Parse 'if' '(' expr ')' compound ['else' compound].

        Both branches are parsed as compound statements; the else branch is
        None when absent.
        """
        if_loc = self.Consume('if').loc
        self.Consume('(')
        cond = self.Expression()
        self.Consume(')')
        then_part = self.parseCompoundStatement()
        else_part = self.parseCompoundStatement() if self.hasConsumed('else') else None
        return astnodes.IfStatement(cond, then_part, else_part, if_loc)

    def parseWhileStatement(self):
        """ Parse 'while' '(' expr ')' compound into an astnodes.WhileStatement. """
        while_loc = self.Consume('while').loc
        self.Consume('(')
        cond = self.Expression()
        self.Consume(')')
        body = self.parseCompoundStatement()
        return astnodes.WhileStatement(cond, body, while_loc)

    def parseReturnStatement(self):
        """ Parse 'return' expr ';' into an astnodes.ReturnStatement.

        An expression is always parsed; a bare 'return;' is not accepted.
        """
        return_loc = self.Consume('return').loc
        value = self.Expression()
        self.Consume(';')
        return astnodes.ReturnStatement(value, return_loc)

    def parseCompoundStatement(self):
        """ Parse '{' {statement} '}' into an astnodes.CompoundStatement.

        Statements for which Statement() returns None (empty statements and
        'var' definitions) are left out of the statement list.
        """
        self.Consume('{')
        body = []
        while not self.hasConsumed('}'):
            stmt = self.Statement()
            if stmt is not None:
                body.append(stmt)
        return astnodes.CompoundStatement(body)

    def Statement(self):
        """ Parse one statement, chosen by the pending token.

        Returns the statement node, or None for an empty statement (';')
        and for a 'var' definition, which is added as a declaration instead.
        """
        kind = self.Peak
        if kind == 'if':
            return self.parseIfStatement()
        if kind == 'while':
            return self.parseWhileStatement()
        if kind == '{':
            return self.parseCompoundStatement()
        if kind == 'return':
            return self.parseReturnStatement()
        if kind == 'var':
            self.parseVarDef()
            return None
        if self.hasConsumed(';'):
            return None
        return self.AssignmentOrCall()

    def AssignmentOrCall(self):
        """ Parse a unary expression, optionally followed by '=' expr.

        With '=' the result is an astnodes.Assignment; otherwise the
        expression is wrapped in an astnodes.ExpressionStatement. No
        terminating ';' is consumed here.
        """
        target = self.UnaryExpression()
        if self.Peak != '=':
            return astnodes.ExpressionStatement(target, target.loc)
        # We enter assignment mode here.
        assign_loc = self.Consume('=').loc
        value = self.Expression()
        return astnodes.Assignment(target, value, assign_loc)

    # Expression section:
    # We do not implement these C constructs:
    # a(2), f = 2
    # and this:
    # a = 2 < x ? 4 : 1;

    def Expression(self):
        """ Parse operands joined by 'or' into a left-associative Binop tree. """
        lhs = self.LogicalAndExpression()
        while True:
            if self.Peak != 'or':
                return lhs
            or_loc = self.Consume('or').loc
            rhs = self.LogicalAndExpression()
            lhs = astnodes.Binop(lhs, 'or', rhs, or_loc)

    def LogicalAndExpression(self):
        """ Parse operands joined by 'and' into a left-associative Binop tree. """
        lhs = self.EqualityExpression()
        while True:
            if self.Peak != 'and':
                return lhs
            and_loc = self.Consume('and').loc
            rhs = self.EqualityExpression()
            lhs = astnodes.Binop(lhs, 'and', rhs, and_loc)

    def EqualityExpression(self):
        """ Parse operands joined by '<', '==' or '>' (left-associative). """
        left = self.SimpleExpression()
        while self.Peak in ('<', '==', '>'):
            operator = self.NextToken()
            right = self.SimpleExpression()
            left = astnodes.Binop(left, operator.typ, right, operator.loc)
        return left

    def SimpleExpression(self):
        """ Parse operands joined by '>>' or '<<' (left-associative).

        The operands are AddExpressions, so in this grammar shifts bind
        less tightly than '+' and '-'.
        """
        left = self.AddExpression()
        while self.Peak in ('>>', '<<'):
            operator = self.NextToken()
            right = self.AddExpression()
            left = astnodes.Binop(left, operator.typ, right, operator.loc)
        return left

    def AddExpression(self):
        """ Parse terms joined by '+' or '-' (left-associative). """
        left = self.Term()
        while self.Peak in ('+', '-'):
            operator = self.NextToken()
            right = self.Term()
            left = astnodes.Binop(left, operator.typ, right, operator.loc)
        return left

    def Term(self):
        """ Parse operands joined by '*' or '/' (left-associative).

        The operands are parsed by BitwiseOr, so '|' and '&' bind more
        tightly than '*' and '/' in this grammar.
        """
        left = self.BitwiseOr()
        while self.Peak in ('*', '/'):
            operator = self.NextToken()
            right = self.BitwiseOr()
            left = astnodes.Binop(left, operator.typ, right, operator.loc)
        return left

    def BitwiseOr(self):
        """ Parse operands joined by '|' (left-associative). """
        left = self.BitwiseAnd()
        while self.Peak == '|':
            operator = self.NextToken()
            right = self.BitwiseAnd()
            left = astnodes.Binop(left, operator.typ, right, operator.loc)
        return left

    def BitwiseAnd(self):
        """ Parse cast expressions joined by '&' (left-associative). """
        left = self.CastExpression()
        while self.Peak == '&':
            operator = self.NextToken()
            right = self.CastExpression()
            left = astnodes.Binop(left, operator.typ, right, operator.loc)
        return left

    # Domain of unary expressions:

    def CastExpression(self):
        """
          Parse 'cast' '<' typespec '>' '(' expr ')' into an
          astnodes.TypeCast, or fall through to a unary expression.

          The C-style type cast conflicts with '(' expr ')', hence the
          extra keyword 'cast'.
        """
        if self.Peak != 'cast':
            return self.UnaryExpression()
        cast_loc = self.Consume('cast').loc
        self.Consume('<')
        target_type = self.parseTypeSpec()
        self.Consume('>')
        self.Consume('(')
        operand = self.Expression()
        self.Consume(')')
        return astnodes.TypeCast(target_type, operand, cast_loc)

    def UnaryExpression(self):
        """ Parse a prefix '&' or '*' applied to a cast expression.

        '*' yields an astnodes.Deref, '&' an astnodes.Unop. Without a prefix
        operator the result is a postfix expression.
        """
        if self.Peak in ['&', '*']:
            op = self.Consume(self.Peak)
            ce = self.CastExpression()
            # Decide on the token type: the Peak check above guarantees it
            # is '&' or '*', and every other operator method in this class
            # uses op.typ as well.
            if op.typ == '*':
                return astnodes.Deref(ce, op.loc)
            return astnodes.Unop(op.typ, ce, op.loc)
        return self.PostFixExpression()

    def PostFixExpression(self):
        """ Parse a primary expression followed by any number of postfixes.

        Supported postfixes:
          '(' [expr {',' expr}] ')'  function call  -> astnodes.FunctionCall
          '->' ID                    deref + field  -> Deref, then FieldRef
          '.' ID                     field access   -> astnodes.FieldRef
        '[' is recognised but array indexing is not implemented; it raises
        CompilerError (via self.Error).
        """
        pfe = self.PrimaryExpression()
        while self.Peak in ['[', '(', '.', '->']:
            if self.Peak == '[':
                # Report the problem at the '[' itself instead of silently
                # dropping the bracket and mis-parsing what follows.
                self.Error('Array indexing is not supported')
            elif self.hasConsumed('('):
                # Function call
                args = []
                if not self.hasConsumed(')'):
                    args.append(self.Expression())
                    while self.hasConsumed(','):
                        args.append(self.Expression())
                    self.Consume(')')
                pfe = astnodes.FunctionCall(pfe, args, pfe.loc)
            elif self.hasConsumed('->'):
                field = self.Consume('ID')
                pfe = astnodes.Deref(pfe, pfe.loc)
                pfe = astnodes.FieldRef(pfe, field.val, field.loc)
            else:
                # Only '.' is left of the token types accepted by the loop.
                self.Consume('.')
                field = self.Consume('ID')
                pfe = astnodes.FieldRef(pfe, field.val, field.loc)
        return pfe

    def PrimaryExpression(self):
        """ Parse a parenthesised expression, a literal or an identifier.

        NUMBER and REAL tokens become literals holding the token value,
        'true' / 'false' become boolean literals and an ID becomes an
        astnodes.VariableUse. Anything else raises CompilerError (via
        self.Error).
        """
        if self.hasConsumed('('):
            inner = self.Expression()
            self.Consume(')')
            return inner
        kind = self.Peak
        if kind in ('NUMBER', 'REAL'):
            tok = self.Consume(kind)
            return astnodes.Literal(tok.val, tok.loc)
        if kind in ('true', 'false'):
            tok = self.Consume(kind)
            return astnodes.Literal(kind == 'true', tok.loc)
        if kind == 'ID':
            designator = self.parseDesignator()
            return astnodes.VariableUse(designator, designator.loc)
        self.Error('Expected NUM, ID or (expr), got {0}'.format(self.Peak))