changeset 171:3eb9b9e2958d

Improved IR code
author Windel Bouwman
date Wed, 03 Apr 2013 22:20:20 +0200
parents 4348da5ca307
children 5a7d37d615ee
files python/c3/codegenerator.py python/ir/__init__.py python/ir/basicblock.py python/ir/builder.py python/ir/instruction.py python/ir/module.py python/testir.py python/x86.py
diffstat 8 files changed, 405 insertions(+), 95 deletions(-) [+]
line wrap: on
line diff
--- a/python/c3/codegenerator.py	Fri Mar 29 17:33:17 2013 +0100
+++ b/python/c3/codegenerator.py	Wed Apr 03 22:20:20 2013 +0200
@@ -1,32 +1,17 @@
 import ir
 from . import astnodes
-
-def NumGen():
-   a = 0
-   while True:
-      yield a
-      a = a + 1
-
-class NameGenerator:
-   def __init__(self, prefix):
-      self.prefix = prefix
-      self.nums = NumGen()
-   def gen(self):
-      return '{0}{1}'.format(self.prefix, self.nums.__next__())
-
+from .scope import boolType
+   
 class CodeGenerator:
    """ Generates intermediate code from a package """
    def gencode(self, pkg):
       assert type(pkg) is astnodes.Package
-      self.m = ir.Module(pkg.name)
-      self.newTmp = NameGenerator('t').gen
-      self.newLab = NameGenerator('lab').gen
+      self.builder = ir.Builder()
+      m = ir.Module(pkg.name)
+      self.builder.setModule(m)
       self.genModule(pkg)
-      return self.m
+      return m
 
-   # Helpers:
-   def addIns(self, i):
-      self.m.Instructions.append(i)
    # inner helpers:
    def genModule(self, pkg):
       for s in pkg.scope:
@@ -36,9 +21,10 @@
          elif type(s) is astnodes.Function:
             # TODO: handle arguments
             # TODO handle return?
-            self.addIns(ir.LabelInstruction(s.name))
+            bb = self.builder.newBB()
+            self.builder.setBB(bb)
             self.genCode(s.body)
-            self.addIns(ir.RetInstruction())
+            self.builder.addIns(ir.Return())
          else:
             print(s)
 
@@ -48,17 +34,19 @@
             self.genCode(s)
       elif type(code) is astnodes.Assignment:
          re = self.genExprCode(code.rval)
-         self.addIns(ir.StoreInstruction(code.lval, re))
+         self.builder.addIns(ir.Store(code.lval, re))
       elif type(code) is astnodes.IfStatement:
-         cr = self.genExprCode(code.condition)
-         t1, t2, te = self.newLab(), self.newLab(), self.newLab()
-         self.addIns(ir.IfInstruction(cr, t1, t2))
-         self.addIns(ir.LabelInstruction(t1))
+         bbtrue = self.builder.newBB()
+         bbfalse = self.builder.newBB()
+         te = self.builder.newBB()
+         self.genCondCode(code.condition, bbtrue, bbfalse)
+         self.builder.setBB(bbtrue)
          self.genCode(code.truestatement)
-         self.addIns(ir.BranchInstruction(te))
-         self.addIns(ir.LabelInstruction(t2))
+         self.builder.addIns(ir.Branch(te))
+         self.builder.setBB(bbfalse)
          self.genCode(code.falsestatement)
-         self.addIns(ir.LabelInstruction(te))
+         self.builder.addIns(ir.Branch(te))
+         self.builder.setBB(te)
       elif type(code) is astnodes.FunctionCall:
          pass
       elif type(code) is astnodes.EmptyStatement:
@@ -67,28 +55,63 @@
          pass
       else:
          print('Unknown stmt:', code)
-
+   def genCondCode(self, expr, bbtrue, bbfalse):
+      # Short-circuit lowering: emit branches to bbtrue/bbfalse instead of computing a boolean value
+      assert expr.typ == boolType
+      if type(expr) is astnodes.Binop:
+         if expr.op == 'or':
+            l2 = self.builder.newBB()  # reached only when the left operand is false
+            self.genCondCode(expr.a, bbtrue, l2)
+            self.builder.setBB(l2)
+            self.genCondCode(expr.b, bbtrue, bbfalse)
+         elif expr.op == 'and':
+            l2 = self.builder.newBB()  # reached only when the left operand is true
+            self.genCondCode(expr.a, l2, bbfalse)
+            self.builder.setBB(l2)
+            self.genCondCode(expr.b, bbtrue, bbfalse)
+         elif expr.op in ['==', '>', '<']:
+            ta = self.genExprCode(expr.a)
+            tb = self.genExprCode(expr.b)
+            i = ir.ConditionalBranch(ta, expr.op, tb, bbtrue, bbfalse)
+            self.builder.addIns(i)
+         else:
+            # Carry the offending expression in the exception instead of an unreachable print
+            raise NotImplementedError('Unknown cond {0}'.format(expr))
+      elif type(expr) is astnodes.Literal:
+         if expr.val:
+            self.builder.addIns(ir.Branch(bbtrue))
+         else:
+            self.builder.addIns(ir.Branch(bbfalse))
+      else:
+         print('Unknown cond', expr)
    def genExprCode(self, expr):
       if type(expr) is astnodes.Binop:
          ra = self.genExprCode(expr.a)
          rb = self.genExprCode(expr.b)
+         tmp = self.builder.newTmp()
          ops = ['+', '-', '*', '/', 'and', 'or']
          if expr.op in ops:
             op = expr.op
-            tmp = self.newTmp()
             ins = ir.BinaryOperator(tmp, op, ra, rb)
-            self.addIns(ins)
+            self.builder.addIns(ins)
             return tmp
          else:
-            print('Unknown binop {0}'.format(expr))
+            print('Unknown {0}'.format(expr))
+            # TODO
+            return tmp
       elif type(expr) is astnodes.Constant:
-         tmp = unique()
+         tmp = self.builder.newTmp()
+         # TODO
+         return tmp
       elif type(expr) is astnodes.VariableUse:
-         tmp = self.newTmp()
+         tmp = self.builder.newTmp()
+         i = ir.Load(expr, tmp)
+         self.builder.addIns(i)
+         return tmp
       elif type(expr) is astnodes.Literal:
-         tmp = self.newTmp()
-         ins = ir.MoveInstruction(tmp, expr.val)
-         self.addIns(ins)
+         tmp = self.builder.newTmp()
+         ins = ir.ImmLoad(tmp, expr.val)
+         self.builder.addIns(ins)
          return tmp
       else:
          print('Unknown expr:', code)
--- a/python/ir/__init__.py	Fri Mar 29 17:33:17 2013 +0100
+++ b/python/ir/__init__.py	Wed Apr 03 22:20:20 2013 +0200
@@ -1,3 +1,4 @@
-from .module import Module
+from .module import *
 from .instruction import *
+from .builder import Builder
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/python/ir/basicblock.py	Wed Apr 03 22:20:20 2013 +0200
@@ -0,0 +1,23 @@
+
+class BasicBlock:
+   # Uninterrupted sequence of instructions.
+   def __init__(self, name):
+      self.name = name
+      self.instructions = []
+   def __repr__(self):
+      return 'BB {0}'.format(self.name)
+   def addIns(self, i):
+      self.instructions.append(i)
+   def getInstructions(self):
+      return self.instructions
+   Instructions = property(getInstructions)
+   def getLastIns(self):
+      return self.instructions[-1]
+   LastIns = property(getLastIns)
+   @property
+   def Empty(self):
+      return len(self.instructions) == 0
+   @property
+   def FirstIns(self):
+      return self.instructions[0]
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/python/ir/builder.py	Wed Apr 03 22:20:20 2013 +0200
@@ -0,0 +1,47 @@
+from . import Value, BasicBlock
+
+class NameGenerator:
+   def __init__(self, prefix):
+      self.prefix = prefix
+      def NumGen():
+         a = 0
+         while True:
+            yield a
+            a = a + 1
+      self.nums = NumGen()
+   def gen(self):
+      return '{0}{1}'.format(self.prefix, self.nums.__next__())
+
+class ValueGenerator(NameGenerator):
+   def __init__(self):
+      super().__init__('t')
+   def gen(self):
+      v = Value(super().gen())
+      return v
+
+class BBGenerator(NameGenerator):
+   def __init__(self):
+      super().__init__('lab')
+   def gen(self):
+      v = BasicBlock(super().gen())
+      return v
+
+class Builder:
+   def __init__(self):
+      self.newTmp = ValueGenerator().gen
+      self.newBBint = BBGenerator().gen
+      self.bb = None
+      self.m = None
+
+   # Helpers:
+   def newBB(self):
+      bb = self.newBBint()
+      self.m.addBB(bb)
+      return bb
+   def setModule(self, m):
+      self.m = m
+   def setBB(self, bb):
+      self.bb = bb
+   def addIns(self, i):
+      self.bb.addIns(i)
+
--- a/python/ir/instruction.py	Fri Mar 29 17:33:17 2013 +0100
+++ b/python/ir/instruction.py	Wed Apr 03 22:20:20 2013 +0200
@@ -1,17 +1,35 @@
+from .basicblock import BasicBlock
+
+class Value:
+   """ Temporary SSA value (a value that is assigned only once). """
+   def __init__(self, name):
+      # TODO: add typing? for now only handle integers
+      self.name = name
+      self.interferes = set()  # filled by Module.registerAllocate with values live at the same time
+      self.reg = None  # register chosen by Module.registerAllocate; None until then
+   def __repr__(self):
+      if self.reg:  # prefer the allocated register name once one is assigned
+         n = self.reg
+      else:
+         n = self.name
+      return '{0}'.format(n)
 
 class Instruction:
    """ Base class for all instructions. """
-   pass
-
-# Label:
-class LabelInstruction(Instruction):
-   def __init__(self, labid):
-      self.labid = labid
-   def __repr__(self):
-      return '{0}:'.format(self.labid)
+   def __init__(self):
+      # successors:
+      self.succ = set()
+      # predecessors:
+      self.pred = set()
+      # live variables at this node:
+      self.live_in = set()
+      self.live_out = set()
+      # What variables this instruction uses and defines:
+      self.defs = set()
+      self.uses = set()
 
 # Function calling:
-class CallInstruction(Instruction):
+class Call(Instruction):
    def __init__(self, callee, arguments):
       super().__init__()
       self.callee = callee
@@ -19,54 +37,84 @@
    def __repr__(self):
       return 'CALL {0}'.format(self.callee)
 
-class RetInstruction(Instruction):
+class Return(Instruction):
    def __repr__(self):
       return 'RET'
 
-class MoveInstruction(Instruction):
-   def __init__(self, name, value):
-      self.name = name
+class ImmLoad(Instruction):
+   def __init__(self, target, value):
+      super().__init__()
+      self.target = target
       self.value = value
+      self.defs.add(target)
    def __repr__(self):
-      return '{0} = {1}'.format(self.name, self.value)
+      return '{0} = {1}'.format(self.target, self.value)
 
+# Data operations
 class BinaryOperator(Instruction):
-   def __init__(self, name, operation, value1, value2):
+   def __init__(self, result, operation, value1, value2):
+      super().__init__()
       #print('operation is in binops:', operation in BinOps)
       # Check types of the two operands:
-      self.name = name
+      self.result = result
+      self.defs.add(result)
       self.value1 = value1
       self.value2 = value2
+      self.uses.add(value1)
+      self.uses.add(value2)
       self.operation = operation
    def __repr__(self):
-      return '{0} = {2} {1} {3}'.format(self.name, self.operation, self.value1, self.value2)
+      return '{0} = {2} {1} {3}'.format(self.result, self.operation, self.value1, self.value2)
 
 # Memory functions:
-class LoadInstruction(Instruction):
+class Load(Instruction):
    def __init__(self, name, value):
+      super().__init__()
       self.value = value
+      self.defs.add(value)
       self.name = name
    def __repr__(self):
       return '{1} = [{0}]'.format(self.name, self.value)
 
-class StoreInstruction(Instruction):
+class Store(Instruction):
    def __init__(self, name, value):
+      super().__init__()
       self.name = name
       self.value = value
+      self.uses.add(value)
    def __repr__(self):
       return '[{0}] = {1}'.format(self.name, self.value)
 
-class BranchInstruction(Instruction):
+# Branching:
+class Branch(Instruction):
    def __init__(self, target):
-      self.t1 = target
+      super().__init__()
+      assert type(target) is BasicBlock
+      self.target = target
    def __repr__(self):
-      return 'BRANCH {0}'.format(self.t1)
+      return 'BRANCH {0}'.format(self.target)
 
-class IfInstruction(Instruction):
-   def __init__(self, cond, lab1, lab2):
+class ConditionalBranch(Instruction):
+   def __init__(self, a, cond, b, lab1, lab2):
+      super().__init__()
+      self.a = a
+      assert type(a) is Value
       self.cond = cond
+      self.b = b
+      self.uses.add(a)
+      self.uses.add(b)
+      assert type(b) is Value
+      assert type(lab1) is BasicBlock
       self.lab1 = lab1
+      assert type(lab2) is BasicBlock
       self.lab2 = lab2
    def __repr__(self):
-      return 'IF {0} THEN {1} ELSE {2}'.format(self.cond, self.lab1, self.lab2)
+      return 'IF {0} {1} {2} THEN {3} ELSE {4}'.format(self.a, self.cond, self.b, self.lab1, self.lab2)
 
+class PhiNode(Instruction):
+   def __init__(self):
+      super().__init__()
+      self.incBB = []
+   def addIncoming(self, bb):
+      self.incBB.append(bb)
+
--- a/python/ir/module.py	Fri Mar 29 17:33:17 2013 +0100
+++ b/python/ir/module.py	Wed Apr 03 22:20:20 2013 +0200
@@ -1,18 +1,113 @@
 # IR-Structures:
+from .instruction import *
+from .basicblock import BasicBlock
 
 class Module:
    """ Main container for a piece of code. """
    def __init__(self, name):
       self.name = name
-      self.instructions = []
+      self.bbs = []
    def __repr__(self):
       return 'IR-module [{0}]'.format(self.name)
    def getInstructions(self):
-      return self.instructions
+      ins = []
+      for bb in self.bbs:
+         ins += bb.Instructions
+      return ins
    Instructions = property(getInstructions)
+   def addBB(self, bb):
+      self.bbs.append(bb)
+   def getBBs(self):
+      return self.bbs
+   BasicBlocks = property(getBBs)
    def dump(self):
       print(self)
       for i in self.Instructions:
-         print(i)
+         print(i, 'live vars:', list(i.live_in), 'uses', list(i.uses), 'defs', list(i.defs))
       print('END')
+   def dumpgv(self, outf):
+      outf.write('digraph G \n{\n')
+      for i in self.Instructions:
+         outf.write('{0} [label="{1}"];\n'.format(id(i), i))
+         for succ in i.succ:
+            outf.write('"{0}" -> "{1}" [label="{2}"];\n'.format(id(i), id(succ), succ.live_in))
+      outf.write('}\n')
 
+   # Analysis functions:
+   def check(self):
+      """ Perform sanity check on module """
+      for i in self.Instructions:
+         for t in i.defs:
+            assert type(t) is Value, "def must be Value, not {0}".format(type(t))
+         for t in i.uses:
+            assert type(t) is Value, "use must be Value, not {0}".format(type(t))
+   def analyze(self):
+      # Determine pred and succ:
+      def link(a, b):
+         a.succ.add(b)
+         b.pred.add(a)
+      for bb in self.bbs:
+         if not bb.Empty:
+            if len(bb.Instructions) > 1:
+               for i1, i2 in zip(bb.Instructions[:-1], bb.Instructions[1:]):
+                  link(i1, i2)
+            else:
+               print('Only 1 long!', bb, bb.Instructions)
+            if type(bb.LastIns) is ConditionalBranch:
+               link(bb.LastIns, bb.LastIns.lab1.FirstIns)
+               link(bb.LastIns, bb.LastIns.lab2.FirstIns)
+            if type(bb.LastIns) is Branch:
+               link(bb.LastIns, bb.LastIns.target.FirstIns)
+         else:
+            print('Empty!', bb)
+      # We now have cfg
+
+      # Determine liveness:
+      for i in self.Instructions:
+         i.live_in = set()
+         i.live_out = set()
+      for z in range(50):
+         # TODO iterate until converge
+         for i in self.Instructions:
+            lo_mk = i.live_out.difference(i.defs)
+            i.live_in = i.uses.union(lo_mk)
+            lo = set()
+            for s in i.succ:
+               lo = lo.union(s.live_in)
+            i.live_out = lo
+   def constantProp(self):
+      """ Constant propagation. Pre-calculate constant values; constval is None when unknown """
+      for i in self.Instructions:
+         if type(i) is ImmLoad:
+            i.target.constval = i.value
+         elif type(i) is BinaryOperator:
+            a = getattr(i.value1, 'constval', None)  # attribute is unset on values defined by e.g. Load
+            b = getattr(i.value2, 'constval', None)
+            if a is not None and b is not None:  # explicit None test: 0 is a valid constant
+               op = i.operation
+               if op == '+':
+                  i.result.constval = a + b  # fold the constants, not the Value objects
+               else:
+                  raise NotImplementedError(op)
+            else:
+               i.result.constval = None
+
+   def registerAllocate(self, regs):
+      print(regs)
+      allVals = []
+      # construct interference:
+      for i in self.Instructions:
+         for v in i.live_in:
+            allVals.append(v)
+            for v2 in i.live_in:
+               if v != v2:
+                  v.interferes.add(v2)
+      # assign random registers:
+      print(allVals)
+      regs = set(regs)
+      for v in allVals:
+         takenregs = set([iv.reg for iv in v.interferes])
+         r2 = list(regs.difference(takenregs))
+         # Pick next available:
+         v.reg = r2[0]
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/python/testir.py	Wed Apr 03 22:20:20 2013 +0200
@@ -0,0 +1,51 @@
+import c3, ppci, ir, x86
+import os
+
+testsrc = """
+package test2;
+
+function void tst()
+{
+   var int a, b;
+   a = 2 * 33 - 12;
+   b = a * 2 + 13;
+   a = b + a;
+   if (a > b and b *3 - a+8*b== 3*6-b)
+   {
+      var int x = a;
+      x = b * 2 - a;
+      a = x * x;
+   }
+   else
+   {
+      a = b + a;
+   }
+   var int y;
+   y = a - b * 53;
+}
+
+"""
+
+if __name__ == '__main__':
+   diag = ppci.DiagnosticsManager()
+   builder = c3.Builder(diag)
+   cgenx86 = x86.X86CodeGen(diag)
+   ir = builder.build(testsrc)
+   ir.check()
+   ir.analyze()
+   #ir.constantProp()
+   ir.dump()
+   asm = cgenx86.genBin(ir)
+   for a in asm:
+      print(a)
+   with open('out.asm', 'w') as f:
+      f.write('BITS 64\n')
+      for a in asm:
+         f.write(str(a) + '\n')
+
+   # Dump a Graphviz file:
+   with open('graaf.gv', 'w') as f:
+      ir.dumpgv(f)
+   os.system('dot -Tpdf -ograaf.pdf graaf.gv')
+
+
--- a/python/x86.py	Fri Mar 29 17:33:17 2013 +0100
+++ b/python/x86.py	Wed Apr 03 22:20:20 2013 +0200
@@ -8,12 +8,19 @@
       return '{0}:'.format(self.lab)
 
 class Op:
-   def __init__(self, op, a, b):
+   def __init__(self, op, dst, src):
       self.op = op
-      self.a = a
-      self.b = b
+      self.src = src
+      self.dst = dst
    def __repr__(self):
-      return '{0} {1}, {2}'.format(self.op, self.a, self.b)
+      return '{0} {1}, {2}'.format(self.op, self.dst, self.src)
+
+class Jmp:
+   def __init__(self, j, target):
+      self.j = j
+      self.target = target
+   def __repr__(self):
+      return '{0} {1}'.format(self.j, self.target)
 
 class X86CodeGen:
    def __init__(self, diag):
@@ -22,42 +29,57 @@
 
    def emit(self, i):
       self.asm.append(i)
-   def allocateReg(self, typ):
-      return 'ax'
-   def deallocateReg(self, r):
-      pass
 
-   def genBin(self, i):
+   def genBin(self, ir):
       self.asm = []
-      self.genModule(i)
+      # Allocate registers:
+      ir.registerAllocate(self.regs)
+      self.genModule(ir)
+      return self.asm
 
-   def genModule(self, m):
-      for g in m.Globals:
-         self.emit(AsmLabel(g.name))
-         # Ignore types for now ..
-         self.emit('dw 0')
-      for f in m.Functions:
-         self.genFunction(f)
+   def genModule(self, ir):
+      #for f in ir.Functions:
+      #   self.genFunction(f)
+      for bb in ir.BasicBlocks:
+         self.genBB(bb)
    def genFunction(self, f):
       self.emit('global {0}'.format(f.name))
       self.emit(AsmLabel(f.name))
       for bb in f.BasicBlocks:
          self.genBB(bb)
    def genBB(self, bb):
+      self.emit(AsmLabel(bb.name))
       for i in bb.Instructions:
          self.genIns(i)
    def genIns(self, i):
       if type(i) is ir.BinaryOperator:
-         if i.operation == 'fadd':
-            r = 'rax'
-            self.emit(Op('add', r, '11'))
-      elif type(i) is ir.LoadInstruction:
-         r = 'rbx'
-         self.emit(Op('mov', r, '{0}'.format(i.value)))
-      elif type(i) is ir.RetInstruction:
+         ops = {'+':'add', '-':'sub', '*':'mul'}
+         if i.operation in ops:
+            self.emit(Op('mov', i.result.reg, i.value1.reg))
+            self.emit(Op(ops[i.operation], i.result.reg, i.value2.reg))
+         else:
+            raise NotImplementedError('op {0}'.format(i.operation))
+      elif type(i) is ir.Load:
+         self.emit(Op('mov', i.value, '[{0}]'.format(i.name)))
+      elif type(i) is ir.Return:
          self.emit('ret')
-      elif type(i) is ir.CallInstruction:
+      elif type(i) is ir.Call:
          self.emit('call')
+      elif type(i) is ir.ImmLoad:
+         self.emit(Op('mov', i.target, i.value))
+      elif type(i) is ir.Store:
+         self.emit(Op('mov', '[{0}]'.format(i.name), i.value))
+      elif type(i) is ir.ConditionalBranch:
+         self.emit(Op('cmp', i.a, i.b))
+         jmps = {'>':'jg', '<':'jl', '==':'je'}
+         if i.cond in jmps:
+            j = jmps[i.cond]
+            self.emit(Jmp(j, i.lab1.name))
+         else:
+            raise NotImplementedError('condition {0}'.format(i.cond))
+         self.emit(Jmp('jmp', i.lab2.name))
+      elif type(i) is ir.Branch:
+         self.emit(Jmp('jmp', i.target.name))
       else:
-         print('Unknown ins', i)
+         raise NotImplementedError('{0}'.format(i))