view python/ppci/c3/parser.py @ 305:0615b5308710

Updated docs
author Windel Bouwman
date Fri, 06 Dec 2013 13:50:38 +0100
parents 6753763d3bec
children b145f8e6050b
line wrap: on
line source

import logging
from ppci import CompilerError
from .lexer import Lexer
from .astnodes import FieldRef, Literal, TypeCast, Unop, Binop
from .astnodes import Assignment, ExpressionStatement, CompoundStatement
from .astnodes import ReturnStatement, WhileStatement, IfStatement
from .astnodes import FunctionType, Function, FormalParameter
from .astnodes import StructureType, DefinedType, PointerType
from .astnodes import Constant, Variable
from .astnodes import StructField, Deref
from .astnodes import Package, ImportDesignator
from .astnodes import Designator, VariableUse, FunctionCall


class Parser:
    """ Parses sourcecode into an abstract syntax tree (AST).

    This is a hand written recursive descent parser with a single token of
    lookahead (``self.token``). Every syntax error is raised as a
    CompilerError through ``Error``; ``parseSource`` hands it to the
    diagnostics object.
    """
    def __init__(self, diag):
        """ Create a parser that reports its errors to diag. """
        self.logger = logging.getLogger('c3')
        self.diag = diag
        self.lexer = Lexer(diag)

    def parseSource(self, source):
        """ Parse source into a package.

        Returns the parsed package, or None when a CompilerError was raised
        (the error is then passed to ``self.diag.addDiag``).
        """
        self.logger.info('Parsing source')
        try:
            # initLex already pulls the first token, so it lives inside the
            # try block: a CompilerError on the very first token must be
            # reported the same way as one on any later token.
            self.initLex(source)
            self.parsePackage()
            return self.mod
        except CompilerError as e:
            self.diag.addDiag(e)
            return None

    def Error(self, msg):
        """ Raise a CompilerError with msg at the location of the pending
        token. """
        raise CompilerError(msg, self.token.loc)

    # Lexer helpers:
    def Consume(self, typ):
        """ Return the pending token and advance when it has type typ,
        raise a CompilerError otherwise. """
        if self.Peak == typ:
            return self.NextToken()
        self.Error('Expected: "{0}", got "{1}"'.format(typ, self.Peak))

    @property
    def Peak(self):
        """ Type of the pending (not yet consumed) token. """
        return self.token.typ

    @property
    def CurLoc(self):
        """ Source location of the pending token. """
        return self.token.loc

    def hasConsumed(self, typ):
        """ Consume the pending token if it has type typ.

        Returns True when a token was consumed, False (without advancing)
        otherwise.
        """
        if self.Peak == typ:
            self.Consume(typ)
            return True
        return False

    def NextToken(self):
        """ Return the pending token and fetch the next one.

        The 'END' token is sticky: once it is pending, the token stream is
        not advanced any further.
        """
        t = self.token
        if t.typ != 'END':
            self.token = next(self.tokens)
        return t

    def initLex(self, source):
        """ Start tokenizing source and load the first token. """
        self.tokens = self.lexer.tokenize(source)
        self.token = next(self.tokens)

    def addDeclaration(self, decl):
        """ Add decl to the part (package or function) being parsed. """
        self.currentPart.declarations.append(decl)

    def parseImport(self):
        """ Parse: 'import' ID ';' and record the name on the package. """
        self.Consume('import')
        name = self.Consume('ID').val
        self.mod.imports.append(name)
        self.Consume(';')

    def parsePackage(self):
        """ Parse: 'module' ID ';' { toplevel } END

        The resulting package is stored in ``self.mod``.
        """
        self.Consume('module')
        name = self.Consume('ID')
        self.Consume(';')
        self.mod = Package(name.val, name.loc)
        self.currentPart = self.mod
        while self.Peak != 'END':
            self.parseTopLevel()
        self.Consume('END')

    def parseTopLevel(self):
        """ Parse one top level construct, selected by the pending token. """
        if self.Peak == 'function':
            self.parseFunctionDef()
        elif self.Peak == 'var':
            self.parseVarDef()
        elif self.Peak == 'const':
            self.parseConstDef()
        elif self.Peak == 'type':
            self.parseTypeDef()
        elif self.Peak == 'import':
            self.parseImport()
        else:
            self.Error('Expected function, var, const or type')

    def parseDesignator(self):
        """ A designator designates an object with a name.

        Parses: ID [ ':' ID ]. The two-name form gives an ImportDesignator,
        the single name form a Designator.
        """
        name = self.Consume('ID')
        if self.hasConsumed(':'):
            name2 = self.Consume('ID')
            return ImportDesignator(name.val, name2.val, name.loc)
        else:
            return Designator(name.val, name.loc)

    # Type system
    def parseTypeSpec(self):
        """ Parse a type specification.

        Either a struct ('struct' '{' { typespec ID { ',' ID } ';' } '}')
        or a designator, followed by any number of '*' pointer suffixes.
        """
        if self.Peak == 'struct':
            self.Consume('struct')
            self.Consume('{')
            mems = []
            while self.Peak != '}':
                # One member line: a type followed by one or more names.
                mem_t = self.parseTypeSpec()
                mem_n = self.Consume('ID').val
                mems.append(StructField(mem_n, mem_t))
                while self.hasConsumed(','):
                    mem_n = self.Consume('ID').val
                    mems.append(StructField(mem_n, mem_t))
                self.Consume(';')
            self.Consume('}')
            theT = StructureType(mems)
        else:
            # For now, a simple type spec is just a designator:
            theT = self.parseDesignator()
        # Check for pointer suffix, each '*' wraps one more pointer level:
        while self.hasConsumed('*'):
            theT = PointerType(theT)
        return theT

    def parseTypeDef(self):
        """ Parse: 'type' typespec ID ';' and declare the new type. """
        self.Consume('type')
        newtype = self.parseTypeSpec()
        typename = self.Consume('ID')
        self.Consume(';')
        df = DefinedType(typename.val, newtype, typename.loc)
        self.addDeclaration(df)

    # Variable declarations:
    def parseVarDef(self):
        """ Parse: 'var' typespec ID [ '=' expr ] { ',' ID [ '=' expr ] } ';'

        Each variable is declared in the current part; an initial value is
        stored in the ``ival`` attribute of the variable.
        """
        self.Consume('var')
        t = self.parseTypeSpec()

        def parseVar():
            name = self.Consume('ID')
            v = Variable(name.val, t)
            v.loc = name.loc
            if self.hasConsumed('='):
                v.ival = self.Expression()
            self.addDeclaration(v)
        parseVar()
        while self.hasConsumed(','):
            parseVar()
        self.Consume(';')

    def parseConstDef(self):
        """ Parse: 'const' typespec ID '=' expr { ',' ID '=' expr } ';'

        Each constant is declared in the current part.
        """
        self.Consume('const')
        t = self.parseTypeSpec()

        def parseConst():
            name = self.Consume('ID')
            self.Consume('=')
            val = self.Expression()
            c = Constant(name.val, t, val)
            c.loc = name.loc
            # Register the constant, like every other kind of declaration;
            # without this the parsed constant would be thrown away.
            self.addDeclaration(c)
        parseConst()
        while self.hasConsumed(','):
            parseConst()
        self.Consume(';')

    # Procedures
    def parseFunctionDef(self):
        """ Parse: 'function' typespec ID '(' [ params ] ')' compound

        The function is declared in the enclosing part. While parameters
        and body are parsed the function itself is the current part, so
        parameters and local variables are declared inside it.
        """
        loc = self.Consume('function').loc
        returntype = self.parseTypeSpec()
        fname = self.Consume('ID').val
        f = Function(fname, loc)
        self.addDeclaration(f)
        savePart = self.currentPart
        self.currentPart = f
        self.Consume('(')
        parameters = []
        if not self.hasConsumed(')'):
            def parseParameter():
                typ = self.parseTypeSpec()
                name = self.Consume('ID')
                param = FormalParameter(name.val, typ)
                param.loc = name.loc
                self.addDeclaration(param)
                parameters.append(param)
            parseParameter()
            while self.hasConsumed(','):
                parseParameter()
            self.Consume(')')
        paramtypes = [p.typ for p in parameters]
        f.typ = FunctionType(paramtypes, returntype)
        f.body = self.parseCompoundStatement()
        self.currentPart = savePart

    # Statements:

    def parseIfStatement(self):
        """ Parse: 'if' '(' expr ')' compound [ 'else' compound ]

        The else part of the IfStatement is None when it is absent.
        """
        loc = self.Consume('if').loc
        self.Consume('(')
        condition = self.Expression()
        self.Consume(')')
        yes = self.parseCompoundStatement()
        if self.hasConsumed('else'):
            no = self.parseCompoundStatement()
        else:
            no = None
        return IfStatement(condition, yes, no, loc)

    def parseWhileStatement(self):
        """ Parse: 'while' '(' expr ')' compound """
        loc = self.Consume('while').loc
        self.Consume('(')
        condition = self.Expression()
        self.Consume(')')
        statements = self.parseCompoundStatement()
        return WhileStatement(condition, statements, loc)

    def parseReturnStatement(self):
        """ Parse: 'return' [ expr ] ';'

        A return without expression returns the literal 0.
        """
        loc = self.Consume('return').loc
        if self.Peak == ';':
            expr = Literal(0, loc)
        else:
            expr = self.Expression()
        self.Consume(';')
        return ReturnStatement(expr, loc)

    def parseCompoundStatement(self):
        """ Parse: '{' { statement } '}'

        Statements that produce no AST node (empty statements and variable
        definitions) are left out of the resulting CompoundStatement.
        """
        self.Consume('{')
        statements = []
        while not self.hasConsumed('}'):
            s = self.Statement()
            if s is None:
                continue
            statements.append(s)
        return CompoundStatement(statements)

    def Statement(self):
        """ Parse one statement.

        Returns the statement node, or None for an empty statement (';')
        and for a variable definition, which only adds declarations.
        """
        # Determine statement type based on the pending token:
        if self.Peak == 'if':
            return self.parseIfStatement()
        elif self.Peak == 'while':
            return self.parseWhileStatement()
        elif self.Peak == '{':
            return self.parseCompoundStatement()
        elif self.hasConsumed(';'):
            return None
        elif self.Peak == 'var':
            self.parseVarDef()
            return None
        elif self.Peak == 'return':
            return self.parseReturnStatement()
        else:
            return self.AssignmentOrCall()

    def AssignmentOrCall(self):
        """ Parse: unary [ '=' expr ]

        With '=' this is an Assignment, otherwise the unary expression is
        wrapped in an ExpressionStatement. The terminating ';' is not
        consumed here; Statement treats it as an empty statement.
        """
        x = self.UnaryExpression()
        if self.Peak == '=':
            # We enter assignment mode here.
            loc = self.Consume('=').loc
            rhs = self.Expression()
            return Assignment(x, rhs, loc)
        else:
            return ExpressionStatement(x, x.loc)

    # Expression section:
    # We do not implement these C constructs:
    # the comma operator:
    # a(2), f = 2
    # and the conditional (ternary) operator:
    # a = 2 < x ? 4 : 1;

    def _parseBinaryLevel(self, parseOperand, operators):
        """ Parse a left associative chain: operand { operator operand } """
        lhs = parseOperand()
        while self.Peak in operators:
            op = self.Consume(self.Peak)
            rhs = parseOperand()
            lhs = Binop(lhs, op.typ, rhs, op.loc)
        return lhs

    def Expression(self):
        """ Parse an expression; 'or' is the loosest binding operator. """
        return self._parseBinaryLevel(self.LogicalAndExpression, ('or',))

    def LogicalAndExpression(self):
        """ Parse: equality { 'and' equality } """
        return self._parseBinaryLevel(self.EqualityExpression, ('and',))

    def EqualityExpression(self):
        """ Parse a chain of comparisons (<, ==, >, >=, <=, !=). """
        return self._parseBinaryLevel(
            self.SimpleExpression, ('<', '==', '>', '>=', '<=', '!='))

    def SimpleExpression(self):
        """ Parse a chain of shifts (>>, <<).

        The operands are add expressions, so + and - bind tighter than the
        shift operators.
        """
        return self._parseBinaryLevel(self.AddExpression, ('>>', '<<'))

    def AddExpression(self):
        """ Parse: term { ('+' | '-') term } """
        return self._parseBinaryLevel(self.Term, ('+', '-'))

    def Term(self):
        """ Parse: bitwiseor { ('*' | '/') bitwiseor }

        Note that the bitwise operators bind tighter than * and / here.
        """
        return self._parseBinaryLevel(self.BitwiseOr, ('*', '/'))

    def BitwiseOr(self):
        """ Parse: bitwiseand { '|' bitwiseand } """
        return self._parseBinaryLevel(self.BitwiseAnd, ('|',))

    def BitwiseAnd(self):
        """ Parse: cast { '&' cast } """
        return self._parseBinaryLevel(self.CastExpression, ('&',))

    # Domain of unary expressions:

    def CastExpression(self):
        """
          the C-style type cast conflicts with '(' expr ')'
          so introduce extra keyword 'cast'

          Parses: 'cast' '<' typespec '>' '(' expr ')' or else a unary
          expression.
        """
        if self.Peak == 'cast':
            loc = self.Consume('cast').loc
            self.Consume('<')
            t = self.parseTypeSpec()
            self.Consume('>')
            self.Consume('(')
            ce = self.Expression()
            self.Consume(')')
            return TypeCast(t, ce, loc)
        else:
            return self.UnaryExpression()

    def UnaryExpression(self):
        """ Parse: ('&' | '*') cast, or else a postfix expression.

        A leading '*' gives a Deref node, a leading '&' a Unop node.
        """
        if self.Peak in ('&', '*'):
            op = self.Consume(self.Peak)
            ce = self.CastExpression()
            # Dispatch on the token type, the same attribute that selected
            # the operator above.
            if op.typ == '*':
                return Deref(ce, op.loc)
            else:
                return Unop(op.typ, ce, op.loc)
        else:
            return self.PostFixExpression()

    def PostFixExpression(self):
        """ Parse a primary expression followed by either one call argument
        list, or any number of field selections ('.' ID and '->' ID).

        'a->b' is built as a field reference on a dereference of a.
        Array indexing is recognised but not implemented.
        """
        pfe = self.PrimaryExpression()
        if self.hasConsumed('('):
            # Function call
            args = []
            if not self.hasConsumed(')'):
                args.append(self.Expression())
                while self.hasConsumed(','):
                    args.append(self.Expression())
                self.Consume(')')
            pfe = FunctionCall(pfe, args, pfe.loc)
        else:
            while self.Peak in ('[', '.', '->'):
                if self.hasConsumed('['):
                    raise NotImplementedError('Array not yet implemented')
                elif self.hasConsumed('->'):
                    field = self.Consume('ID')
                    pfe = Deref(pfe, pfe.loc)
                    pfe = FieldRef(pfe, field.val, field.loc)
                elif self.hasConsumed('.'):
                    field = self.Consume('ID')
                    pfe = FieldRef(pfe, field.val, field.loc)
                else:
                    # Cannot happen: the loop condition admits only the
                    # three token types handled above.
                    raise Exception()
        return pfe

    def PrimaryExpression(self):
        """ Parse: '(' expr ')' | NUMBER | REAL | 'true' | 'false' | designator

        Raises a CompilerError when the pending token starts none of these.
        """
        if self.hasConsumed('('):
            e = self.Expression()
            self.Consume(')')
            return e
        elif self.Peak == 'NUMBER':
            val = self.Consume('NUMBER')
            return Literal(val.val, val.loc)
        elif self.Peak == 'REAL':
            val = self.Consume('REAL')
            return Literal(val.val, val.loc)
        elif self.Peak == 'true':
            val = self.Consume('true')
            return Literal(True, val.loc)
        elif self.Peak == 'false':
            val = self.Consume('false')
            return Literal(False, val.loc)
        elif self.Peak == 'ID':
            d = self.parseDesignator()
            return VariableUse(d, d.loc)
        self.Error('Expected NUM, ID or (expr), got {0}'.format(self.Peak))