changeset 252:c4370696ccc7

added optimize function
author Windel Bouwman
date Tue, 30 Jul 2013 17:57:46 +0200
parents 6ed3d3a82a63
children 74c6a20302d5
files python/c3/codegenerator.py python/c3/typecheck.py python/ir/basicblock.py python/ir/instruction.py python/testc3.py python/testir.py python/transform.py python/zcc.py
diffstat 8 files changed, 259 insertions(+), 161 deletions(-) [+]
line wrap: on
line diff
--- a/python/c3/codegenerator.py	Mon Jul 29 20:23:13 2013 +0200
+++ b/python/c3/codegenerator.py	Tue Jul 30 17:57:46 2013 +0200
@@ -190,7 +190,7 @@
             self.builder.addIns(ins)
             return self.cast_to_rvalue(expr, tmp2)
         elif type(expr) is astnodes.Literal:
-            tmp = self.builder.newTmp()
+            tmp = self.builder.newTmp('const')
             ins = ir.ImmLoad(tmp, expr.val)
             self.builder.addIns(ins)
             return tmp
--- a/python/c3/typecheck.py	Mon Jul 29 20:23:13 2013 +0200
+++ b/python/c3/typecheck.py	Tue Jul 30 17:57:46 2013 +0200
@@ -159,6 +159,8 @@
                    sym.typ = intType
                    self.error('Types unequal {} != {}'.format(sym.a.typ, sym.b.typ), sym.loc)
             elif sym.op in ['>', '<', '==', '<=', '>=']:
+                expectRval(sym.a)
+                expectRval(sym.b)
                 sym.typ = boolType
                 if not equalTypes(sym.a.typ, sym.b.typ):
                    self.error('Types unequal {} != {}'.format(sym.a.typ, sym.b.typ), sym.loc)
@@ -186,5 +188,5 @@
         elif type(sym) in [CompoundStatement, Package, Function, FunctionType, ExpressionStatement, DefinedType]:
             pass
         else:
-            raise Exception('Unknown type check {0}'.format(sym))
+            raise NotImplementedError('Unknown type check {0}'.format(sym))
 
--- a/python/ir/basicblock.py	Mon Jul 29 20:23:13 2013 +0200
+++ b/python/ir/basicblock.py	Tue Jul 30 17:57:46 2013 +0200
@@ -21,6 +21,7 @@
 
     def removeInstruction(self, i):
         i.parent = None
+        i.delete()
         self.instructions.remove(i)
 
     def getInstructions(self):
--- a/python/ir/instruction.py	Mon Jul 29 20:23:13 2013 +0200
+++ b/python/ir/instruction.py	Tue Jul 30 17:57:46 2013 +0200
@@ -14,9 +14,17 @@
         return '{0}'.format(self.name) # + str(self.IsUsed)
 
     @property
+    def UsedInBlocks(self):
+        bbs = [i.parent for i in self.used_by + [self.Setter]]
+        assert all(isinstance(bb, BasicBlock) for bb in bbs)
+        return set(bbs)
+
+    @property
     def IsUsed(self):
         return len(self.used_by) > 0
 
+    Used = IsUsed
+
     def onlyUsedInBlock(self, bb):
         for use in self.used_by:
             ins = use
@@ -85,24 +93,32 @@
     Parent = property(getParent, setParent)
 
     def replaceValue(self, old, new):
-        raise NotImplementedError()
+        raise NotImplementedError('{}'.format(type(self)))
 
     @property
     def Position(self):
         return self.parent.Instructions.index(self)
 
+    @property
+    def Function(self):
+        return self.Block.parent
+
+    @property
+    def Block(self):
+        return self.Parent
+
     def check(self):
         # Check that the variables defined by this instruction 
         # are only used in the same block
         for v in self.defs:
             assert v.Setter is self
             for ub in v.used_by:
-                assert ub.parent == self.parent
+                assert ub.Function == self.Function
 
         # Check that variables used are defined earlier:
         for u in self.uses:
             v = u.val
-            assert self.Position > v.Setter.Position
+            #assert self.Position > v.Setter.Position
 
 
 
@@ -206,19 +222,20 @@
 
 # Memory functions:
 class Load(Instruction):
-   def __init__(self, location, value):
-      super().__init__()
-      assert type(value) is Value
-      assert isinstance(location, Value), "Location must be a value"
-      self.value = value
-      self.addDef(value)
-      self.location = location
-      self.addUse(self.location)
-   def __repr__(self):
-      return '{} = [{}]'.format(self.value, self.location)
+    def __init__(self, location, value):
+        super().__init__()
+        assert type(value) is Value
+        assert isinstance(location, Value), "Location must be a value"
+        self.value = value
+        self.addDef(value)
+        self.location = location
+        self.addUse(self.location)
+
+    def __repr__(self):
+        return '{} = [{}]'.format(self.value, self.location)
 
 class Store(Instruction):
-   def __init__(self, location, value):
+    def __init__(self, location, value):
       super().__init__()
       assert type(value) is Value, value
       assert isinstance(location, Value), "Location must be a value"
@@ -226,8 +243,19 @@
       self.value = value
       self.addUse(value)
       self.addUse(location)
-   def __repr__(self):
-      return '[{}] = {}'.format(self.location, self.value)
+    
+    def __repr__(self):
+        return '[{}] = {}'.format(self.location, self.value)
+
+    def replaceValue(self, old, new):
+        if old is self.value:
+            self.value = new
+        elif old is self.location:
+            self.location = new
+        else:
+            raise Exception()
+        self.removeUse(old)
+        self.addUse(new)
 
 # Branching:
 class Branch(Terminator):
@@ -270,9 +298,10 @@
          self.lab2 = tto
 
 class PhiNode(Instruction):
-   def __init__(self):
-      super().__init__()
-      self.incBB = []
-   def addIncoming(self, bb):
-      self.incBB.append(bb)
+    def __init__(self):
+        super().__init__()
+        self.incBB = []
 
+    def addIncoming(self, bb):
+        self.incBB.append(bb)
+
--- a/python/testc3.py	Mon Jul 29 20:23:13 2013 +0200
+++ b/python/testc3.py	Tue Jul 30 17:57:46 2013 +0200
@@ -174,11 +174,11 @@
       block_code = """ a0 = alloc
         b1 = alloc
         c2 = alloc
-        t3 = 1
-        [a0] = t3
+        const3 = 1
+        [a0] = const3
         t4 = [a0]
-        t5 = 2
-        mul6 = t4 * t5
+        const5 = 2
+        mul6 = t4 * const5
         t7 = [a0]
         t8 = [a0]
         mul9 = t7 * t8
@@ -187,8 +187,8 @@
         t11 = [b1]
         t12 = [a0]
         mul13 = t11 * t12
-        t14 = 3
-        sub15 = mul13 - t14
+        const14 = 3
+        sub15 = mul13 - const14
         [c2] = sub15
       ret """
       self.expectIR(snippet, block_code)
@@ -286,11 +286,11 @@
         """
         block_code = """a0 = alloc
          b1 = alloc
-         t2 = 2
-         [a0] = t2
+         const2 = 2
+         [a0] = const2
          t3 = [a0]
-         t4 = 2
-         add5 = t3 + t4
+         const4 = 2
+         add5 = t3 + const4
          [b1] = add5
          ret  """
         self.expectIR(snippet, block_code)
@@ -315,15 +315,15 @@
          }
         """
         block_code = """a0 = alloc
-         t1 = 2
+         const1 = 2
          off_x2 = 0
          adr_x3 = a0 + off_x2
-         [adr_x3] = t1
+         [adr_x3] = const1
          off_x4 = 0
          adr_x5 = a0 + off_x4
          t6 = [adr_x5]
-         t7 = 2
-         add8 = t6 + t7
+         const7 = 2
+         add8 = t6 + const7
          off_y9 = 4
          adr_y10 = a0 + off_y9
          [adr_y10] = add8
@@ -382,11 +382,11 @@
          }
         """
         block_code = """a0 = alloc
-         t1 = 40
-         [a0] = t1
-         t2 = 2
+         const1 = 40
+         [a0] = const1
+         const2 = 2
          deref3 = [a0]
-         [deref3] = t2
+         [deref3] = const2
          ret  """
         self.expectIR(snippet, block_code)
 
@@ -403,19 +403,19 @@
          }
         """
         block_code = """a0 = alloc
-         t1 = 40
-         [a0] = t1
-         t2 = 2
+         const1 = 40
+         [a0] = const1
+         const2 = 2
          deref3 = [a0]
          off_x4 = 0
          adr_x5 = deref3 + off_x4
-         [adr_x5] = t2
+         [adr_x5] = const2
          deref6 = [a0]
          off_x7 = 0
          adr_x8 = deref6 + off_x7
          t9 = [adr_x8]
-         t10 = 14
-         sub11 = t9 - t10
+         const10 = 14
+         sub11 = t9 - const10
          deref12 = [a0]
          off_y13 = 4
          adr_y14 = deref12 + off_y13
--- a/python/testir.py	Mon Jul 29 20:23:13 2013 +0200
+++ b/python/testir.py	Tue Jul 30 17:57:46 2013 +0200
@@ -1,4 +1,5 @@
 import unittest, os
+import sys
 import c3, ppci, ir, x86, transform
 
 class ConstantFolderTestCase(unittest.TestCase):
@@ -47,20 +48,23 @@
    b = a * 2 + 13;
    a = b + a;
    cee = a;
-   if (a > b and b *3 - a+8*b== 3*6-b)
+   cee = cee * 2 + a + cee * 2;
+   if (cee + a > b and b *3 - a+8*b== 3*6-b)
    {
       var int x = a;
       x = b * 2 - a;
-      a = x * x * add2(x, 22 - a);
+      a = x * x * (x + 22 - a);
    }
    else
    {
-      a = b + a + add2(a, b);
+      a = b + a + (a + b);
    }
    var int y;
    y = a - b * 53;
 }
+"""
 
+testsrc2 = """
 function int add2(int x, int y)
 {
    var int res;
@@ -86,38 +90,32 @@
 """
 
 if __name__ == '__main__':
-   #unittest.main()
-   #sys.exit()
-   diag = ppci.DiagnosticsManager()
-   builder = c3.Builder(diag)
-   cgenx86 = x86.X86CodeGenSimple(diag)
-   ir = builder.build(testsrc)
-   diag.printErrors(testsrc)
-   ir.dump()
-   ir.check()
-   cf = transform.ConstantFolder()
-   ir.check()
-   dcd = transform.DeadCodeDeleter()
-   ir.check()
-   m2r = transform.Mem2RegPromotor()
-   ir.check()
-   clr = transform.CleanPass()
-   ir.check()
-   cf.run(ir)
-   dcd.run(ir)
-   clr.run(ir)
-   m2r.run(ir)
-   #ir.dump()
+    #unittest.main()
+    #sys.exit()
+    diag = ppci.DiagnosticsManager()
+    builder = c3.Builder(diag)
+    cgenx86 = x86.X86CodeGenSimple(diag)
+    ir = builder.build(testsrc)
+    diag.printErrors(testsrc)
+    ir.check()
+    ir.dump()
+    transform.optimize(ir)
+    print('dump IR')
+    print('dump IR')
+    print('dump IR')
+    print('dump IR')
+    ir.dump()
 
-   # Dump a graphiz file:
-   with open('graaf.gv', 'w') as f:
+    # Dump a graphiz file:
+    with open('graaf.gv', 'w') as f:
       ir.dumpgv(f)
-   os.system('dot -Tpdf -ograaf.pdf graaf.gv')
+    os.system('dot -Tsvg -ograaf.svg graaf.gv')
 
-   asm = cgenx86.genBin(ir)
-   #for a in asm:
-   #   print(a)
-   with open('out.asm', 'w') as f:
+    sys.exit()
+    asm = cgenx86.genBin(ir)
+    #for a in asm:
+    #   print(a)
+    with open('out.asm', 'w') as f:
       f.write('BITS 64\n')
       for a in asm:
          f.write(str(a) + '\n')
--- a/python/transform.py	Mon Jul 29 20:23:13 2013 +0200
+++ b/python/transform.py	Tue Jul 30 17:57:46 2013 +0200
@@ -39,44 +39,14 @@
 
 # Usefull transforms:
 class ConstantFolder(InstructionPass):
-   def prepare(self):
-      self.constMap = {}
-   def onInstruction(self, i):
+    def prepare(self):
+        self.constMap = {}
+
+    def onInstruction(self, i):
       if type(i) is ImmLoad:
          self.constMap[i.target] = i.value
       elif type(i) is BinaryOperator:
-         if i.value1 in self.constMap and i.value2 in self.constMap:
-            op = i.operation
-            va = self.constMap[i.value1]
-            vb = self.constMap[i.value2]
-            if op == '+':
-               vr = va + vb
-            elif op == '*':
-               vr = va * vb
-            elif op == '-':
-               vr = va - vb
-            else:
-               vr = None
-               return
-            self.constMap[i.result] = vr
-            i.removeDef(i.result)
-            i2 = ImmLoad(i.result, vr)
-            i.Parent.replaceInstruction(i, i2)
-
-
-class ConstantMerge(InstructionPass):
-    def prepare(self):
-        self.constMap = {}
-    def onInstruction(self, i):
-        if type(i) is ImmLoad:
-            v = i.value
-            if v in self.constMap:
-                # v is already defined, re-use the imm-load from elsewhere
-                pass
-            else:
-                self.constMap[v] = i
-        elif type(i) is BinaryOperator:
-         if i.value1 in self.constMap and i.value2 in self.constMap:
+         if i.value1 in self.constMap and i.value2 in self.constMap and i.operation in ['+', '-', '*', '<<']:
             op = i.operation
             va = self.constMap[i.value1]
             vb = self.constMap[i.value2]
@@ -86,56 +56,67 @@
                vr = va * vb
             elif op == '-':
                vr = va - vb
+            elif op == '<<':
+                vr = va << vb
             else:
                vr = None
                return
             self.constMap[i.result] = vr
+            i.removeDef(i.result)
             i2 = ImmLoad(i.result, vr)
+            print('Replacing', i)
             i.Parent.replaceInstruction(i, i2)
 
 
 class DeadCodeDeleter(BasicBlockPass):
-   def onBasicBlock(self, bb):
-      def instructionUsed(ins):
-         if len(ins.defs) == 0:
-            # In case this instruction does not define any 
-            # variables, assume it is usefull.
-            return True
-         for d in ins.defs:
-            if d.IsUsed:
-               return True
-         return False
-      bb.Instructions = list(filter(instructionUsed, bb.Instructions))
+    def onBasicBlock(self, bb):
+        def instructionUsed(ins):
+            if len(ins.defs) == 0:
+                # In case this instruction does not define any 
+                # variables, assume it is usefull.
+                return True
+            return any(d.Used for d in ins.defs)
+
+        change = True
+        while change:
+            change = False
+            for i in bb.Instructions:
+                if instructionUsed(i):
+                    continue
+                bb.removeInstruction(i)
+                change = True
 
 
-class SameImmLoadDeletePass(BasicBlockPass):
+class CommonSubexpressionElimination(BasicBlockPass):
     def onBasicBlock(self, bb):
         constMap = {}
-        imms = filter(lambda i: isinstance(i, ImmLoad), bb.Instructions)
-        for ins in list(imms):
-            if ins.value in constMap:
-                # remove this immload and update all references to the target
-                t_old = ins.target
-                if not t_old.onlyUsedInBlock(bb):
-                    continue
-                # update all references:
-                t_new = constMap[ins.value]
-                for use in t_old.used_by:
-                    use.replaceValue(t_old, t_new)
-                bb.removeInstruction(ins)
-            else:
-                constMap[ins.value] = ins.target
+        to_remove = []
+        for i in bb.Instructions:
+            if isinstance(i, ImmLoad):
+                if i.value in constMap:
+                    t_new = constMap[i.value]
+                    t_old = i.target
+                    for ui in t_old.used_by:
+                        ui.replaceValue(t_old, t_new)
+                    to_remove.append(i)
+                else:
+                    constMap[i.value] = i.target
+            elif isinstance(i, BinaryOperator):
+                k = (i.value1, i.operation, i.value2)
+                if k in constMap:
+                    print('Duplicate binop!', i)
+                    t_old = i.result
+                    t_new = constMap[k]
+                    for ui in t_old.used_by:
+                        ui.replaceValue(t_old, t_new)
+                    to_remove.append(i)
+                else:
+                    constMap[k] = i.result
+        for i in to_remove:
+            print('removing ', i)
+            bb.removeInstruction(i)
 
-
-def isAllocPromotable(allocinst):
-   # Check if alloc value is only used by load and store operations.
-   assert type(allocinst) is Alloc
-   for use in ai.value.used_by:
-      if not type(use.user) in [Load, Store]:
-         # TODO: check volatile
-         return False
-         otherUse = True
-   return True
+            
 
 
 class CleanPass(FunctionPass):
@@ -159,8 +140,101 @@
                     f.removeBasicBlock(bb)
 
 
+def isAllocPromotable(allocinst):
+   # Check if alloc value is only used by load and store operations.
+   assert type(allocinst) is Alloc
+   for use in allocinst.value.used_by:
+      if not type(use) in [Load, Store]:
+         # TODO: check volatile
+         return False
+         otherUse = True
+   return True
+
+
 class Mem2RegPromotor(FunctionPass):
-   def onFunction(self, f):
-      # TODO
-      pass
+    def promoteSingleBlock(self, ai):
+        print('Single block:', ai)
+        v = ai.value
+        bb = ai.Block
+
+        # Replace all loads with the value:
+        loads = [i for i in v.used_by if isinstance(i, Load)]
+        stores = [i for i in v.used_by if isinstance(i, Store)]
+        stores.sort(key=lambda s: s.Position)
+        stores.reverse()
+        print(stores)
+
+        for load in loads:
+            idx = load.Position
+            # Search upwards:
+            for store in stores:
+                if store.Position < load.Position:
+                    break
+            #print('replace {} with {}'.format(load, store.value))
+            for use_ins in load.value.used_by:
+                use_ins.replaceValue(load.value, store.value)
+            assert not load.value.Used
+            print('removing {}'.format(load))
+            bb.removeInstruction(load)
+
+        # Remove store instructions:
+        for store in stores:
+            sv = store.value
+            print('removing {}'.format(store))
+            bb.removeInstruction(store)
+            #assert sv.Used
+        
+        # Remove alloca instruction:
+        assert not ai.value.Used, ai.value.used_by
+        bb.removeInstruction(ai)
+            
+
 
+    def promote(self, ai):
+        # Find load operations and replace them with assignments
+        v = ai.value
+        if len(ai.value.UsedInBlocks) == 1:
+            self.promoteSingleBlock(ai)
+            return
+        
+        loads = [i for i in v.used_by if isinstance(i, Load)]
+        stores = [i for i in v.used_by if isinstance(i, Store)]
+
+        # Each store instruction can be removed (later).
+        # Instead of storing the value, we use it 
+        # where the load would have been!
+        replMap = {}
+        for store in stores:
+            replMap[store] = store.value
+
+        # for each load, track back what the defining store
+        # was.
+        for load in loads:
+            print(load)
+
+    def onFunction(self, f):
+        # TODO
+        for bb in f.BasicBlocks:
+            allocs = [i for i in bb.Instructions if isinstance(i, Alloc)]
+            for i in allocs:
+                print(i, isAllocPromotable(i))
+                if isAllocPromotable(i):
+                    self.promote(i)
+
+def optimize(ir):
+    cf = ConstantFolder()
+    dcd = DeadCodeDeleter()
+    m2r = Mem2RegPromotor()
+    clr = CleanPass()
+    cse = CommonSubexpressionElimination()
+    ir.check()
+    cf.run(ir)
+    dcd.run(ir)
+    ir.check()
+    clr.run(ir)
+    ir.check()
+    m2r.run(ir)
+    ir.check()
+    cse.run(ir)
+    ir.check()
+
--- a/python/zcc.py	Mon Jul 29 20:23:13 2013 +0200
+++ b/python/zcc.py	Tue Jul 30 17:57:46 2013 +0200
@@ -3,7 +3,7 @@
 import sys, argparse
 import c3, ppci, codegen
 import codegenarm
-from transform import CleanPass, SameImmLoadDeletePass
+import transform
 import outstream
 import hexfile
 
@@ -23,13 +23,7 @@
         return
 
     # Optimization passes:
-    ircode.check()
-    cp = CleanPass()
-    cp.run(ircode)
-    ircode.check()
-    sidp = SameImmLoadDeletePass()
-    sidp.run(ircode)
-    ircode.check()
+    transform.optimize(ircode)
 
     if dumpir:
         ircode.dump()