diff python/ppci/codegen/canon.py @ 301:6753763d3bec

merge codegen into ppci package
author Windel Bouwman
date Thu, 05 Dec 2013 17:02:38 +0100
parents python/codegen/canon.py@158068af716c
children e609d5296ee9
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/python/ppci/codegen/canon.py	Thu Dec 05 17:02:38 2013 +0100
@@ -0,0 +1,129 @@
+from .. import ir
+from itertools import chain
+
+def make(function, frame):
+    """
+        Create canonicalized version of the IR-code. This means:
+        - Calls out of expressions.
+        - Other things?
+    """
+    # Change the tree. This modifies the IR-tree!
+    # Move all parameters into registers
+    parmoves = []
+    for p in function.arguments:
+        pt = newTemp()
+        frame.parMap[p] = pt
+        parmoves.append(ir.Move(pt, frame.argLoc(p.num)))
+    function.entry.instructions = parmoves + function.entry.instructions
+
+    for block in function.Blocks:
+        for stmt in block.instructions:
+            rewriteStmt(stmt, frame)
+        linearize(block)
+    # TODO: schedule here?
+
+# Visit all nodes with some function:
+# TODO: rewrite into visitor.
+
+# Rewrite rewrites call instructions into Eseq instructions.
+
+
+def rewriteStmt(stmt, frame):
+    if isinstance(stmt, ir.Jump):
+        pass
+    elif isinstance(stmt, ir.CJump):
+        stmt.a = rewriteExp(stmt.a, frame)
+        stmt.b = rewriteExp(stmt.b, frame)
+    elif isinstance(stmt, ir.Move):
+        stmt.src = rewriteExp(stmt.src, frame)
+        stmt.dst = rewriteExp(stmt.dst, frame)
+    elif isinstance(stmt, ir.Terminator):
+        pass
+    elif isinstance(stmt, ir.Exp):
+        stmt.e = rewriteExp(stmt.e, frame)
+    else:
+        raise NotImplementedError('STMT NI: {}'.format(stmt))
+
+newTemp = ir.NamedClassGenerator('canon_reg', ir.Temp).gen
+
+def rewriteExp(exp, frame):
+    if isinstance(exp, ir.Binop):
+        exp.a = rewriteExp(exp.a, frame)
+        exp.b = rewriteExp(exp.b, frame)
+        return exp
+    elif isinstance(exp, ir.Const):
+        return exp
+    elif isinstance(exp, ir.Temp):
+        return exp
+    elif isinstance(exp, ir.Parameter):
+        return frame.parMap[exp]
+    elif isinstance(exp, ir.LocalVariable):
+        offset = frame.allocVar(exp)
+        return ir.Add(frame.fp, ir.Const(offset))
+    elif isinstance(exp, ir.Mem):
+        exp.e = rewriteExp(exp.e, frame)
+        return exp
+    elif isinstance(exp, ir.Call):
+        exp.arguments = [rewriteExp(p, frame) for p in exp.arguments]
+        # Rewrite call into eseq:
+        t = newTemp()
+        return ir.Eseq(ir.Move(t, exp), t)
+    else:
+        raise NotImplementedError('NI: {}'.format(exp))
+        
+# The flatten functions pull out seq instructions to the sequence list.
+
+def flattenExp(exp):
+    if isinstance(exp, ir.Binop):
+        exp.a, sa = flattenExp(exp.a)
+        exp.b, sb = flattenExp(exp.b)
+        return exp, sa + sb
+    elif isinstance(exp, ir.Temp):
+        return exp, []
+    elif isinstance(exp, ir.Const):
+        return exp, []
+    elif isinstance(exp, ir.Mem):
+        exp.e, s = flattenExp(exp.e)
+        return exp, s
+    elif isinstance(exp, ir.Eseq):
+        s = flattenStmt(exp.stmt)
+        exp.e, se = flattenExp(exp.e)
+        return exp.e, s + se
+    elif isinstance(exp, ir.Call):
+        sp = []
+        p = []
+        for p_, sp_ in (flattenExp(p) for p in exp.arguments):
+            p.append(p_)
+            sp.extend(sp_)
+        exp.arguments = p
+        return exp, sp
+    else:
+        raise NotImplementedError('NI: {}'.format(exp))
+
+
+def flattenStmt(stmt):
+    if isinstance(stmt, ir.Jump):
+        return [stmt]
+    elif isinstance(stmt, ir.CJump):
+        stmt.a, sa = flattenExp(stmt.a)
+        stmt.b, sb = flattenExp(stmt.b)
+        return sa + sb + [stmt]
+    elif isinstance(stmt, ir.Move):
+        stmt.dst, sd = flattenExp(stmt.dst)
+        stmt.src, ss = flattenExp(stmt.src)
+        return sd + ss + [stmt]
+    elif isinstance(stmt, ir.Terminator):
+        return [stmt]
+    elif isinstance(stmt, ir.Exp):
+        stmt.e, se = flattenExp(stmt.e)
+        return se + [stmt]
+    else:
+        raise NotImplementedError('STMT NI: {}'.format(stmt))
+
+
+def linearize(block):
+    """ 
+      Move seq instructions to top and flatten these in an instruction list
+    """
+    i = list(flattenStmt(s) for s in block.instructions)
+    block.instructions = list(chain.from_iterable(i))