changeset 258:04c19282a5aa

Added register allocator
author Windel Bouwman
date Mon, 05 Aug 2013 19:46:11 +0200
parents 703321743e8a
children ac603eb66b63
files python/codegenarm.py python/cortexm3.py python/ide.py python/ir/basicblock.py python/ir/instruction.py python/ir/module.py python/outstream.py python/zcc.py
diffstat 8 files changed, 77 insertions(+), 48 deletions(-) [+]
line wrap: on
line diff
--- a/python/codegenarm.py	Sun Aug 04 18:34:06 2013 +0200
+++ b/python/codegenarm.py	Mon Aug 05 19:46:11 2013 +0200
@@ -11,7 +11,7 @@
     """
     def __init__(self, out):
         self.outs = out
-        self.logger = logging.getLogger('cgarm')
+        self.logger = logging.getLogger('codegenarm')
 
     def emit(self, item):
         self.outs.emit(item)
@@ -19,6 +19,8 @@
     def generate(self, ircode):
         assert isinstance(ircode, ir.Module)
         self.logger.info('Generating arm code for {}'.format(ircode.name))
+        self.available_regs = {arm.r2, arm.r3, arm.r4, arm.r5, arm.r6, arm.r7}
+        self.regmap = {}
         # TODO: get these from linker descriptor?
         self.outs.getSection('code').address = 0x08000000
         self.outs.getSection('data').address = 0x20000000
@@ -62,8 +64,8 @@
             self.align()
         self.outs.backpatch()
         self.outs.backpatch()
-        code = self.outs.getSection('code').to_bytes()
-        self.logger.info('Generated {} bytes code'.format(len(code)))
+        codesize = self.outs.getSection('code').Size
+        self.logger.info('Generated {} bytes code'.format(codesize))
 
     def dcd(self, x):
         self.emit(arm.dcd_ins(Imm32(x)))
@@ -87,6 +89,17 @@
     def loadStack(self, reg, val):
         self.emit(arm.ldr_sprel(reg, arm.MemSpRel(self.getStack(val))))
 
+    def getreg(self, v):
+        if not v in self.regmap:
+            self.regmap[v] = self.available_regs.pop()
+        return self.regmap[v]
+
+    def freereg(self, v, ins):
+        if v.lastUse(ins):
+            r = self.regmap.pop(v)
+            assert r not in self.regmap.values()
+            self.available_regs.add(r)
+
     def comment(self, txt):
         self.emit(Comment(txt))
 
@@ -103,53 +116,58 @@
             self.emit(arm.b_ins(tgt))
         elif type(ins) is ir.ImmLoad:
             lname = ins.target.name + '_ivalue'
-            self.emit(arm.ldr_pcrel(arm.r0, LabelRef(lname)))
+            r0 = self.getreg(ins.target)
+            self.emit(arm.ldr_pcrel(r0, LabelRef(lname)))
             self.imms.append((lname, ins.value))
-            self.emit(arm.str_sprel(arm.r0, arm.MemSpRel(self.addStack(ins.target))))
         elif type(ins) is ir.Store:
             # Load value in r0:
-            self.loadStack(arm.r0, ins.value)
+            r0 = self.getreg(ins.value)
             # store in memory:
             # TODO: split globals and locals??
             #self.getGlobal(arm.r1, ins.location)
             # Horrible hack with localVars
             if ins.location in self.localVars:
                 # The value was alloc'ed
-                self.emit(arm.str_sprel(arm.r0, arm.MemSpRel(self.getStack(ins.location))))
+                self.emit(arm.str_sprel(r0, arm.MemSpRel(self.getStack(ins.location))))
             else:
-                self.loadStack(arm.r1, ins.location)
-                self.emit(arm.storeimm5_ins(arm.r0, arm.MemR8Rel(arm.r1, 0)))
+                r1 = self.getreg(ins.location)
+                self.emit(arm.storeimm5_ins(r0, arm.MemR8Rel(r1, 0)))
+            self.freereg(ins.location, ins)
+            self.freereg(ins.value, ins)
         elif type(ins) is ir.Load:
             # TODO: differ global and local??
             #self.getGlobal(arm.r0, ins.location)
+            r0 = self.getreg(ins.value)
             if ins.location in self.localVars:
-                self.emit(arm.ldr_sprel(arm.r0, arm.MemSpRel(self.getStack(ins.location))))
+                self.emit(arm.ldr_sprel(r0, arm.MemSpRel(self.getStack(ins.location))))
             else:
-                self.loadStack(arm.r0, ins.location)
-                self.emit(arm.loadimm5_ins(arm.r0, arm.MemR8Rel(arm.r0, 0)))
-            # Store value on stack:
-            self.emit(arm.str_sprel(arm.r0, arm.MemSpRel(self.addStack(ins.value))))
+                r2 = self.getreg(ins.location)
+                self.emit(arm.loadimm5_ins(r0, arm.MemR8Rel(r2, 0)))
+            self.freereg(ins.location, ins)
         elif type(ins) is ir.BinaryOperator:
             # Load operands:
-            self.loadStack(arm.r0, ins.value1)
-            self.loadStack(arm.r1, ins.value2)
+            r0 = self.getreg(ins.value1)
+            r1 = self.getreg(ins.value2)
+            r2 = self.getreg(ins.result)
             # do operation:
             if ins.operation == '+':
-                self.emit(arm.addregs_ins(arm.r0, arm.r0, arm.r1))
+                self.emit(arm.addregs_ins(r2, r0, r1))
             elif ins.operation == '<<':
-                self.emit(arm.lslregs_ins(arm.r0, arm.r1))
+                self.emit(arm.movregreg_ins(r2, r0))
+                self.emit(arm.lslregs_ins(r2, r1))
             elif ins.operation == '|':
-                self.emit(arm.orrregs_ins(arm.r0, arm.r1))
+                self.emit(arm.movregreg_ins(r2, r0))
+                self.emit(arm.orrregs_ins(r2, r1))
             else:
                 raise NotImplementedError('operation {} not implemented'.format(ins.operation))
-            # Store value back:
-            self.emit(arm.str_sprel(arm.r0, arm.MemSpRel(self.addStack(ins.result))))
+            self.freereg(ins.value1, ins)
+            self.freereg(ins.value2, ins)
         elif type(ins) is ir.Return:
             self.emit(arm.pop_ins(arm.RegisterSet({arm.r4, arm.r5, arm.r6, arm.r7, arm.pc})))
         elif type(ins) is ir.ConditionalBranch:
-            self.loadStack(arm.r0, ins.a)
-            self.loadStack(arm.r1, ins.b)
-            self.emit(arm.cmp_ins(arm.r1, arm.r0))
+            r0 = self.getreg(ins.a)
+            r1 = self.getreg(ins.b)
+            self.emit(arm.cmp_ins(r1, r0))
             tgt_yes = Label(ins.lab1.name)
             if ins.cond == '==':
                 self.emit(arm.beq_ins(tgt_yes))
@@ -157,6 +175,8 @@
                 raise NotImplementedError('"{}" not covered'.format(ins.cond))
             tgt_no = Label(ins.lab2.name)
             self.emit(arm.jmp_ins(tgt_no))
+            self.freereg(ins.a, ins)
+            self.freereg(ins.b, ins)
         elif type(ins) is ir.Alloc:
             # Local variables are added to stack
             self.addStack(ins.value)
--- a/python/cortexm3.py	Sun Aug 04 18:34:06 2013 +0200
+++ b/python/cortexm3.py	Mon Aug 05 19:46:11 2013 +0200
@@ -341,19 +341,6 @@
         return 'MOV {0}, xx?'.format(self.r)
 
 
-@armtarget.instruction
-class movregreg_ins(ArmInstruction):
-    """ mov Rd, Rm """
-    mnemonic = 'mov'
-    operands = (Reg8Op, Reg8Op)
-    def __init__(self, rd, rm):
-        self.rd = rd
-        self.rm = rm
-    def encode(self):
-        rd = self.rd.num
-        rm = self.rm.num
-        h = 0 | (rm << 3) | rd
-        return u16(h)
 
 
 
@@ -419,6 +406,12 @@
         return '{} {}, {}'.format(self.mnemonic, self.rdn, self.rm)
 
 @armtarget.instruction
+class movregreg_ins(regreg_base):
+    """ mov Rd, Rm """
+    mnemonic = 'mov'
+    opcode = 0
+
+@armtarget.instruction
 class andregs_ins(regreg_base):
     mnemonic = 'AND'
     opcode = 0b0100000000
--- a/python/ide.py	Sun Aug 04 18:34:06 2013 +0200
+++ b/python/ide.py	Mon Aug 05 19:46:11 2013 +0200
@@ -16,13 +16,11 @@
 import zcc
 import outstream
 
-logformat='%(asctime)s|%(levelname)s|%(name)s|%(msg)s'
-
 class BuildOutput(QTextEdit):
     """ Build output component """
     def __init__(self, parent=None):
         super(BuildOutput, self).__init__(parent)
-        fmt = logging.Formatter(fmt=logformat)
+        fmt = logging.Formatter(fmt=zcc.logformat)
         class MyHandler(logging.Handler):
             def emit(self2, x):
                 self.append(str(fmt.format(x)))
@@ -316,7 +314,7 @@
 
 
 if __name__ == '__main__':
-    logging.basicConfig(format=logformat, level=logging.DEBUG)
+    logging.basicConfig(format=zcc.logformat, level=logging.DEBUG)
     app = QApplication(sys.argv)
     ide = Ide()
     ide.show()
--- a/python/ir/basicblock.py	Sun Aug 04 18:34:06 2013 +0200
+++ b/python/ir/basicblock.py	Mon Aug 05 19:46:11 2013 +0200
@@ -55,6 +55,9 @@
         return preds
     Predecessors = property(getPredecessors)
 
+    def precedes(self, other):
+        raise NotImplementedError()
+
     def check(self):
         for ins in self.Instructions:
             ins.check()
--- a/python/ir/instruction.py	Sun Aug 04 18:34:06 2013 +0200
+++ b/python/ir/instruction.py	Mon Aug 05 19:46:11 2013 +0200
@@ -32,6 +32,9 @@
                 return False
         return True
 
+    def lastUse(self, ins):
+        assert ins in self.used_by
+        return all(not ins.precedes(ub) for ub in self.used_by)
 
 class Variable(Value):
     pass
@@ -97,7 +100,13 @@
 
     @property
     def Position(self):
-        return self.parent.Instructions.index(self)
+        return self.Block.Instructions.index(self)
+
+    def precedes(self, other):
+        assert isinstance(other, Instruction)
+        if self.Block is other.Block:
+            return other.Position > self.Position
+        return self.Block.precedes(other.Block)
 
     @property
     def Function(self):
--- a/python/ir/module.py	Sun Aug 04 18:34:06 2013 +0200
+++ b/python/ir/module.py	Mon Aug 05 19:46:11 2013 +0200
@@ -20,14 +20,14 @@
     Instructions = property(getInstructions)
 
     def getBBs(self):
-      bbs = []
-      for f in self.Functions:
-         bbs += f.BasicBlocks
-      return bbs
+        bbs = []
+        for f in self.Functions:
+            bbs += f.BasicBlocks
+        return bbs
 
     BasicBlocks = property(getBBs)
     def addFunc(self, f):
-      self.funcs.append(f)
+        self.funcs.append(f)
     addFunction = addFunc
 
     def addVariable(self, v):
--- a/python/outstream.py	Sun Aug 04 18:34:06 2013 +0200
+++ b/python/outstream.py	Mon Aug 05 19:46:11 2013 +0200
@@ -25,6 +25,10 @@
             d.extend(insword)
         return bytes(d)
 
+    @property
+    def Size(self):
+        return len(self.to_bytes())
+
     def debugInfos(self):
         di = [i for i in self.instructions if isinstance(i, DebugInfo)]
         return di
--- a/python/zcc.py	Sun Aug 04 18:34:06 2013 +0200
+++ b/python/zcc.py	Mon Aug 05 19:46:11 2013 +0200
@@ -45,8 +45,10 @@
     obj = cg.generate(ircode)
     return True
 
+logformat='%(asctime)s|%(levelname)s|%(name)s|%(message)s'
+
 def main(args):
-    logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s', level=args.log)
+    logging.basicConfig(format=logformat, level=args.log)
     src = args.source.read()
     args.source.close()
     diag = ppci.DiagnosticsManager()