diff python/ppci/transform.py @ 300:158068af716c

yafm
author Windel Bouwman
date Tue, 03 Dec 2013 18:00:22 +0100
parents python/transform.py@9417caea2eb3
children 6753763d3bec
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/python/ppci/transform.py	Tue Dec 03 18:00:22 2013 +0100
@@ -0,0 +1,142 @@
+"""
+ Transformation to optimize IR-code
+"""
+
+import logging
+import ir
+# Standard passes:
+
+class FunctionPass:
+    def __init__(self):
+        self.logger = logging.getLogger('optimize')
+
+    def run(self, ir):
+        """ Main entry point for the pass """
+        self.logger.info('Running pass {}'.format(type(self)))
+        ir.check()
+        self.prepare()
+        for f in ir.Functions:
+            self.onFunction(f)
+        ir.check()
+
+    def onFunction(self, f):
+        """ Override this virtual method """
+        raise NotImplementedError()
+
+    def prepare(self):
+        pass
+
+
+class BasicBlockPass(FunctionPass):
+    def onFunction(self, f):
+        for bb in f.Blocks:
+            self.onBasicBlock(bb)
+
+    def onBasicBlock(self, bb):
+        """ Override this virtual method """
+        raise NotImplementedError()
+
+
+class InstructionPass(BasicBlockPass):
+    def onBasicBlock(self, bb):
+        for ins in iter(bb.Instructions):
+            self.onInstruction(ins)
+
+    def onInstruction(self, ins):
+        """ Override this virtual method """
+        raise NotImplementedError()
+
+
+class BasePass(BasicBlockPass):
+    def onBasicBlock(self, bb):
+        pass
+
+
+# Usefull transforms:
+class ConstantFolder(BasePass):
+    def __init__(self):
+        super().__init__()
+        self.ops = {}
+        self.ops['+'] = lambda x, y: x + y
+        self.ops['-'] = lambda x, y: x - y
+        self.ops['*'] = lambda x, y: x * y
+        self.ops['<<'] = lambda x, y: x << y
+
+    def postExpr(self, expr):
+        if type(i) is BinaryOperator and i.operation in self.ops.keys() and type(i.a) is Const and type(i.b) is Const:
+            vr = self.ops[i.operation](i.a.value, i.b.value)
+            return Const(vr)
+        else:
+            return expr
+
+
+class DeadCodeDeleter(BasicBlockPass):
+    def onBasicBlock(self, bb):
+        def instructionUsed(ins):
+            if not type(ins) in [ImmLoad, BinaryOperator]:
+                return True
+            if len(ins.defs) == 0:
+                # In case this instruction does not define any 
+                # variables, assume it is usefull.
+                return True
+            return any(d.Used for d in ins.defs)
+
+        change = True
+        while change:
+            change = False
+            for i in bb.Instructions:
+                if instructionUsed(i):
+                    continue
+                bb.removeInstruction(i)
+                change = True
+
+
+class CommonSubexpressionElimination(BasicBlockPass):
+    def onBasicBlock(self, bb):
+        constMap = {}
+        to_remove = []
+        for i in bb.Instructions:
+            if isinstance(i, ImmLoad):
+                if i.value in constMap:
+                    t_new = constMap[i.value]
+                    t_old = i.target
+                    logging.debug('Replacing {} with {}'.format(t_old, t_new))
+                    t_old.replaceby(t_new)
+                    to_remove.append(i)
+                else:
+                    constMap[i.value] = i.target
+            elif isinstance(i, BinaryOperator):
+                k = (i.value1, i.operation, i.value2)
+                if k in constMap:
+                    t_old = i.result
+                    t_new = constMap[k]
+                    logging.debug('Replacing {} with {}'.format(t_old, t_new))
+                    t_old.replaceby(t_new)
+                    to_remove.append(i)
+                else:
+                    constMap[k] = i.result
+        for i in to_remove:
+            logging.debug('removing {}'.format(i))
+            bb.removeInstruction(i)
+
+
+class CleanPass(FunctionPass):
+    def onFunction(self, f):
+        removeEmptyBasicBlocks(f)
+
+
+def removeEmptyBlocks(f):
+    """ Remove empty basic blocks from function. """
+    # If a block only contains a branch, it can be removed:
+    empty = lambda b: type(b.FirstInstruction) is ir.Jump
+    empty_blocks = list(filter(empty, f.Blocks))
+    for b in empty_blocks:
+        # Update predecessors
+        preds = b.Predecessors
+        if b not in preds + [f.entry]:
+            # Do not remove if preceeded by itself
+            tgt = b.LastInstruction.target
+            for pred in preds:
+                  pred.LastInstruction.changeTarget(b, tgt)
+            logging.debug('Removing empty block: {}'.format(b))
+            f.removeBlock(b)