view python/transform.py @ 254:bd26dc13f270

Added logger
author Windel Bouwman
date Wed, 31 Jul 2013 21:20:58 +0200
parents 74c6a20302d5
children 7416c923a02a
line wrap: on
line source

"""
 Transformation passes that optimize IR code
"""

import logging
from ir import *
# Standard passes:

class FunctionPass:
    """ Base class for a pass that visits each function of an IR module.

    Subclasses implement onFunction; prepare may be overridden to reset
    per-run state before the first function is visited.
    """

    def run(self, ir):
        """ Main entry point: log the pass, call prepare, then hand every
            function in ir.Functions to onFunction in order. """
        pass_kind = type(self)
        logging.info('Running pass {}'.format(pass_kind))
        self.prepare()
        for function in ir.Functions:
            self.onFunction(function)

    def onFunction(self, f):
        """ Visit a single function; subclasses must override this. """
        raise NotImplementedError()

    def prepare(self):
        """ Hook called once at the start of run(); does nothing by default. """


class BasicBlockPass(FunctionPass):
    """ Pass that visits every basic block of every function.

    Subclasses implement onBasicBlock.
    """

    def onFunction(self, f):
        """ Hand each block in f.BasicBlocks to onBasicBlock, in order. """
        for block in f.BasicBlocks:
            self.onBasicBlock(block)

    def onBasicBlock(self, bb):
        """ Visit one basic block; subclasses must override this. """
        raise NotImplementedError()


class InstructionPass(BasicBlockPass):
    """ Pass that visits every instruction of every basic block.

    Subclasses implement onInstruction.
    """

    def onBasicBlock(self, bb):
        """ Call onInstruction for each instruction of bb, in order.

        Walks a snapshot of bb.Instructions: subclasses may edit the block
        while visiting it (e.g. via replaceInstruction), and if
        bb.Instructions is the block's live list, mutating it during the
        loop could skip or revisit instructions.
        """
        for ins in list(bb.Instructions):
            self.onInstruction(ins)

    def onInstruction(self, ins):
        """ Visit one instruction; subclasses must override this. """
        raise NotImplementedError()

# Useful transforms:
class ConstantFolder(InstructionPass):
    """ Fold binary operations whose operands are both known constants.

    Constants are learned from ImmLoad instructions as they are visited.
    A BinaryOperator with two constant operands and a supported operation
    is replaced in its block by an ImmLoad of the computed value, and that
    result becomes a known constant so later operations can fold too.
    """

    # Supported operations and how to evaluate them at compile time.
    # NOTE(review): values are combined with Python's own operators, so if
    # they are Python ints no wrap-around to a target word size is applied
    # — confirm this is intended.
    _FOLDERS = {
        '+': lambda a, b: a + b,
        '-': lambda a, b: a - b,
        '*': lambda a, b: a * b,
        '<<': lambda a, b: a << b,
    }

    def prepare(self):
        """ Reset the map from IR values to their known constant values. """
        self.constMap = {}

    def onInstruction(self, i):
        """ Record constants from ImmLoad; try to fold a BinaryOperator. """
        if type(i) is ImmLoad:
            self.constMap[i.target] = i.value
        elif type(i) is BinaryOperator:
            self._foldBinaryOperator(i)

    def _foldBinaryOperator(self, i):
        """ Replace i by an ImmLoad when both operands are known constants. """
        fold = self._FOLDERS.get(i.operation)
        if fold is None:
            return
        if i.value1 not in self.constMap or i.value2 not in self.constMap:
            return
        va = self.constMap[i.value1]
        vb = self.constMap[i.value2]
        if i.operation == '<<' and vb < 0:
            # Python raises ValueError for a negative shift count; leave the
            # instruction unfolded instead of aborting the whole pass.
            return
        vr = fold(va, vb)
        self.constMap[i.result] = vr
        i.removeDef(i.result)
        i2 = ImmLoad(i.result, vr)
        logging.debug('Replacing {}'.format(i))
        i.Parent.replaceInstruction(i, i2)


class DeadCodeDeleter(BasicBlockPass):
    """ Remove instructions whose defined values are never used.

    Only ImmLoad and BinaryOperator instructions are candidates for
    removal; every other instruction type is always kept.
    """

    def onBasicBlock(self, bb):
        def instructionUsed(ins):
            """ Return True if ins must be kept in the block. """
            if type(ins) not in (ImmLoad, BinaryOperator):
                return True
            if len(ins.defs) == 0:
                # In case this instruction does not define any
                # variables, assume it is useful.
                return True
            return any(d.Used for d in ins.defs)

        # Removing an instruction may make the values it reads unused, so
        # repeat until a full sweep removes nothing.
        change = True
        while change:
            change = False
            # Walk a snapshot: removing from a list while iterating it skips
            # the following element. Reverse order visits users (which
            # normally follow their definitions) first, so most dead chains
            # disappear within a single sweep.
            for i in reversed(list(bb.Instructions)):
                if instructionUsed(i):
                    continue
                bb.removeInstruction(i)
                change = True


class CommonSubexpressionElimination(BasicBlockPass):
    """ Reuse values that are computed more than once within a basic block.

    Repeated ImmLoads of the same constant, and repeated BinaryOperators
    with the same (value1, operation, value2), are redirected to the value
    of the first occurrence and then removed from the block.
    """

    def onBasicBlock(self, bb):
        # Maps an expression key to the value that first computed it:
        # ImmLoad -> its constant, BinaryOperator -> (value1, op, value2).
        constMap = {}
        to_remove = []
        for i in bb.Instructions:
            if isinstance(i, ImmLoad):
                key, produced = i.value, i.target
            elif isinstance(i, BinaryOperator):
                key, produced = (i.value1, i.operation, i.value2), i.result
            else:
                continue
            if key in constMap:
                self._replaceUses(produced, constMap[key])
                to_remove.append(i)
            else:
                constMap[key] = produced
        # Remove only after the walk so bb.Instructions is not modified
        # underneath the loop above.
        for i in to_remove:
            logging.debug('removing {}'.format(i))
            bb.removeInstruction(i)

    @staticmethod
    def _replaceUses(t_old, t_new):
        """ Make every user of t_old read t_new instead. """
        logging.debug('Replacing {} with {}'.format(t_old, t_new))
        # Iterate over a copy: replaceValue may update t_old.used_by, and
        # changing a collection while iterating it skips elements (list)
        # or raises RuntimeError (set).
        for ui in list(t_old.used_by):
            ui.replaceValue(t_old, t_new)


class CleanPass(FunctionPass):
    """ Remove basic blocks that consist of nothing but a Branch.

    Every predecessor of such a block is retargeted to the block's branch
    destination, after which the block is removed from the function.
    """

    def onFunction(self, f):
        # Snapshot: blocks are removed from f while we loop.
        for bb in list(f.BasicBlocks):
            # Only a block holding a single Branch can be bypassed.
            if len(bb.Instructions) != 1 or type(bb.LastInstruction) is not Branch:
                continue
            ins = bb.LastInstruction
            preds = bb.Predecessors
            if bb in preds:
                # Do not remove a block that precedes itself: retargeting
                # would leave a branch to the removed block.
                continue
            # NOTE(review): a branch-only block with no predecessors (e.g.
            # the function entry) is removed too — confirm that
            # removeBasicBlock handles the entry block correctly.
            # Snapshot preds in case retargeting updates bb.Predecessors.
            for pred in list(preds):
                pred.LastInstruction.changeTarget(bb, ins.target)
            f.removeBasicBlock(bb)