view python/ppci/codegen/canon.py @ 361:614a7f6d4d4d

Fixed test
author Windel Bouwman
date Fri, 14 Mar 2014 16:18:54 +0100
parents 5477e499b039
children c49459768aaa
line wrap: on
line source

from .. import ir
from .. import irutils
from itertools import chain

def make(function, frame):
    """
        Canonicalize the IR code of function, modifying the IR tree in place.

        Steps:
        - Each argument of function gets a fresh temporary, recorded in
          frame.parMap, and a move from frame.argLoc(<argument number>)
          into that temporary is prepended to the entry block.
        - Every statement of every block is rewritten with rewriteStmt
          (parameters, local variables and calls; see rewriteExp).
        - Every block is then linearized, which hoists the statements of
          the Eseq expressions introduced by the rewrite out of the
          expression trees.
    """
    # Copy each incoming argument into its own temporary at function entry.
    entry_moves = []
    for arg in function.arguments:
        reg = newTemp()
        frame.parMap[arg] = reg
        entry_moves.append(ir.Move(reg, frame.argLoc(arg.num)))
    function.entry.instructions = entry_moves + function.entry.instructions

    # Rewrite, then flatten, one block at a time.
    for blk in function.Blocks:
        for instruction in blk.instructions:
            rewriteStmt(instruction, frame)
        linearize(blk)
    # TODO: schedule here?

# The functions below walk the IR tree recursively, dispatching on the node
# type with isinstance checks.
# TODO: rewrite into a visitor.

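# The rewrite functions replace parameters by their temporaries, local
# variables by frame-pointer-relative addresses, and call expressions by
# Eseq expressions.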


def rewriteStmt(stmt, frame):
    """
        Rewrite the operand expressions of a single statement in place.

        Each operand is replaced by the result of rewriteExp. Jump and
        Terminator statements are left untouched. Returns None.

        Raises NotImplementedError for unsupported statement types.
    """
    # Order of the checks matters: CJump is tested before the generic
    # Terminator case.
    if isinstance(stmt, ir.Jump):
        return
    if isinstance(stmt, ir.CJump):
        stmt.a = rewriteExp(stmt.a, frame)
        stmt.b = rewriteExp(stmt.b, frame)
        return
    if isinstance(stmt, ir.Move):
        stmt.src = rewriteExp(stmt.src, frame)
        stmt.dst = rewriteExp(stmt.dst, frame)
        return
    if isinstance(stmt, ir.Terminator):
        return
    if not isinstance(stmt, ir.Exp):
        raise NotImplementedError('STMT NI: {}'.format(stmt))
    stmt.e = rewriteExp(stmt.e, frame)

# Source of the temporaries this pass introduces (argument copies in make()
# and call results in rewriteExp()). It is the .gen method of an
# irutils.NamedClassGenerator built from 'canon_reg' and ir.Temp; presumably
# each call returns a new ir.Temp named after 'canon_reg' — confirm in irutils.
newTemp = irutils.NamedClassGenerator('canon_reg', ir.Temp).gen

def rewriteExp(exp, frame):
    """
        Rewrite an expression tree and return the (possibly new) expression.

        - Binop, Mem and Addr nodes have their children rewritten in place
          and are returned themselves.
        - Const and Temp nodes are returned unchanged.
        - Parameters are replaced by their entry in frame.parMap (make()
          fills it with fresh temporaries).
        - Local variables are replaced by ir.Add(frame.fp, ir.Const(offset)),
          with offset taken from frame.allocVar.
        - Calls have their arguments rewritten and are then wrapped in an
          Eseq that moves the call result into a fresh temporary and yields
          that temporary.

        Raises NotImplementedError for unsupported expression types.
    """
    if isinstance(exp, ir.Binop):
        exp.a = rewriteExp(exp.a, frame)
        exp.b = rewriteExp(exp.b, frame)
        return exp
    if isinstance(exp, (ir.Const, ir.Temp)):
        return exp
    if isinstance(exp, ir.Parameter):
        return frame.parMap[exp]
    if isinstance(exp, ir.LocalVariable):
        # NOTE(review): allocVar is called on every occurrence; presumably it
        # returns the same offset for the same variable — confirm in frame.
        var_offset = frame.allocVar(exp)
        return ir.Add(frame.fp, ir.Const(var_offset))
    if isinstance(exp, (ir.Mem, ir.Addr)):
        exp.e = rewriteExp(exp.e, frame)
        return exp
    if isinstance(exp, ir.Call):
        exp.arguments = [rewriteExp(arg, frame) for arg in exp.arguments]
        # Pull the call into a statement: Eseq(Move(tmp, call), tmp).
        result = newTemp()
        return ir.Eseq(ir.Move(result, exp), result)
    raise NotImplementedError('NI: {}, {}'.format(exp, type(exp)))
        
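# The flatten functions hoist the statements of Eseq expressions out of the
# expression tree and return them as a list of statements that must run
# before the flattened expression or statement.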

def flattenExp(exp):
    """
        Flatten an expression by hoisting out the statements of nested Eseqs.

        Returns a tuple (expression, statements):
        - expression is the flattened expression. Eseq nodes are replaced
          by their inner expression; other nodes have their children updated
          in place and are returned themselves.
        - statements is the list of statements, in order, that must be
          executed before expression is evaluated.

        Raises NotImplementedError for unsupported expression types.
    """
    if isinstance(exp, ir.Binop):
        exp.a, sa = flattenExp(exp.a)
        exp.b, sb = flattenExp(exp.b)
        # NOTE(review): sb is executed before exp.a is read. If sb has side
        # effects that exp.a observes (e.g. a call writing memory that exp.a
        # loads), evaluation order changes — confirm the source language
        # leaves operand evaluation order unspecified.
        return exp, sa + sb
    elif isinstance(exp, (ir.Temp, ir.Const)):
        return exp, []
    elif isinstance(exp, (ir.Mem, ir.Addr)):
        exp.e, s = flattenExp(exp.e)
        return exp, s
    elif isinstance(exp, ir.Eseq):
        # The Eseq node itself disappears: its statement runs first, then
        # whatever its inner expression needs, and the inner expression
        # takes its place.
        s = flattenStmt(exp.stmt)
        exp.e, se = flattenExp(exp.e)
        return exp.e, s + se
    elif isinstance(exp, ir.Call):
        # Flatten arguments left to right, collecting their hoisted
        # statements in the same order (same evaluation-order caveat as
        # for Binop applies to earlier vs. later arguments).
        flat_args = []
        hoisted = []
        for arg in exp.arguments:
            flat_arg, arg_stmts = flattenExp(arg)
            flat_args.append(flat_arg)
            hoisted.extend(arg_stmts)
        exp.arguments = flat_args
        return exp, hoisted
    else:
        raise NotImplementedError('NI: {}'.format(exp))


def flattenStmt(stmt):
    """
        Flatten one statement and return the list of statements replacing it.

        The list holds the statements hoisted out of the statement's
        operands (in operand order), followed by the statement itself,
        whose operands are updated in place. Jump and Terminator statements
        are returned alone.

        Raises NotImplementedError for unsupported statement types.
    """
    # Keep this check order: CJump and Move must be seen before the generic
    # Terminator case.
    if isinstance(stmt, ir.Jump):
        return [stmt]
    if isinstance(stmt, ir.CJump):
        lhs, lhs_pre = flattenExp(stmt.a)
        stmt.a = lhs
        rhs, rhs_pre = flattenExp(stmt.b)
        stmt.b = rhs
        return lhs_pre + rhs_pre + [stmt]
    if isinstance(stmt, ir.Move):
        dst, dst_pre = flattenExp(stmt.dst)
        stmt.dst = dst
        src, src_pre = flattenExp(stmt.src)
        stmt.src = src
        return dst_pre + src_pre + [stmt]
    if isinstance(stmt, ir.Terminator):
        return [stmt]
    if isinstance(stmt, ir.Exp):
        value, value_pre = flattenExp(stmt.e)
        stmt.e = value
        return value_pre + [stmt]
    raise NotImplementedError('STMT NI: {}'.format(stmt))


def linearize(block):
    """
        Flatten every statement of block and replace its instruction list.

        Each statement is passed through flattenStmt, which places the
        statements hoisted out of its Eseq expressions directly in front of
        it. The resulting lists are concatenated in the original statement
        order and become the new block.instructions.
    """
    # chain.from_iterable consumes the per-statement lists directly; no
    # intermediate list of lists is needed.
    block.instructions = list(
        chain.from_iterable(map(flattenStmt, block.instructions)))