changeset 170:4348da5ca307

Cleanup of ir dir
author Windel Bouwman
date Fri, 29 Mar 2013 17:33:17 +0100
parents ee0d30533dae
children 3eb9b9e2958d
files python/c3/codegenerator.py python/c3/typecheck.py python/ir/__init__.py python/ir/asmwriter.py python/ir/bitreader.py python/ir/context.py python/ir/instruction.py python/ir/module.py python/ir/symboltable.py python/ir/value.py python/testc3.py
diffstat 11 files changed, 155 insertions(+), 404 deletions(-) [+]
line wrap: on
line diff
--- a/python/c3/codegenerator.py	Sat Mar 23 18:34:41 2013 +0100
+++ b/python/c3/codegenerator.py	Fri Mar 29 17:33:17 2013 +0100
@@ -1,103 +1,95 @@
 import ir
 from . import astnodes
 
-def genModule(pkg):
-   m = ir.Module(pkg.name)
-   for s in pkg.scope:
-      if type(s) is astnodes.Variable:
-         genGlobal(m, s)
-      elif type(s) is astnodes.Function:
-         genFunction(m, s)
-      else:
-         print(s)
-   return m
-
-def genGlobal(m, var):
-   v = ir.Value()
-   v.name = var.name
-   m.Globals.append(v)
-
-def genFunction(m, fnc):
-   ft = genType(fnc.typ)
-   f = ir.Function(fnc.name, ft)
-   m.Functions.append(f)
-   bb = ir.BasicBlock()
-   f.BasicBlocks.append(bb)
-   genCode(bb, fnc.body)
-   bb.Instructions.append(ir.RetInstruction())
-
-def genCode(bb, code):
-   if type(code) is astnodes.CompoundStatement:
-      for s in code.statements:
-         genCode(bb, s)
-   elif type(code) is astnodes.Assignment:
-      genExprCode(bb, code.rval)
-      # TODO: store
-   elif type(code) is astnodes.IfStatement:
-      genExprCode(bb, code.condition)
-      # TODO: implement IF.
-      t1, t2 = 1, 2
-      b = ir.BranchInstruction(t1, t2)
-      bb.Instructions.append(b)
-      genCode(bb, code.truestatement)
-      genCode(bb, code.falsestatement)
-   elif type(code) is astnodes.FunctionCall:
-      ins = ir.CallInstruction('f', [])
-      bb.Instructions.append(ins)
-   elif type(code) is astnodes.EmptyStatement:
-      pass
-   elif type(code) is astnodes.ReturnStatement:
-      bb.Instructions.append(ir.RetInstruction())
-   else:
-      print('Unknown stmt:', code)
-
 def NumGen():
    a = 0
    while True:
       yield a
       a = a + 1
 
-nums = NumGen()
-def unique():
-   return 'tmp{0}'.format(nums.__next__())
+class NameGenerator:
+   def __init__(self, prefix):
+      self.prefix = prefix
+      self.nums = NumGen()
+   def gen(self):
+      return '{0}{1}'.format(self.prefix, self.nums.__next__())
+
+class CodeGenerator:
+   """ Generates intermediate code from a package """
+   def gencode(self, pkg):
+      assert type(pkg) is astnodes.Package
+      self.m = ir.Module(pkg.name)
+      self.newTmp = NameGenerator('t').gen
+      self.newLab = NameGenerator('lab').gen
+      self.genModule(pkg)
+      return self.m
+
+   # Helpers:
+   def addIns(self, i):
+      self.m.Instructions.append(i)
+   # inner helpers:
+   def genModule(self, pkg):
+      for s in pkg.scope:
+         if type(s) is astnodes.Variable:
+            # TODO
+            pass
+         elif type(s) is astnodes.Function:
+            # TODO: handle arguments
+            # TODO handle return?
+            self.addIns(ir.LabelInstruction(s.name))
+            self.genCode(s.body)
+            self.addIns(ir.RetInstruction())
+         else:
+            print(s)
 
-def genExprCode(bb, code):
-   if type(code) is astnodes.Binop:
-      a = genExprCode(bb, code.a)
-      b = genExprCode(bb, code.b)
-      ops = {'+': 'fadd', '-': 'fsub', '*':'fmul', '/':'fdiv'}
-      if code.op in ops:
-         op = ops[code.op]
+   def genCode(self, code):
+      if type(code) is astnodes.CompoundStatement:
+         for s in code.statements:
+            self.genCode(s)
+      elif type(code) is astnodes.Assignment:
+         re = self.genExprCode(code.rval)
+         self.addIns(ir.StoreInstruction(code.lval, re))
+      elif type(code) is astnodes.IfStatement:
+         cr = self.genExprCode(code.condition)
+         t1, t2, te = self.newLab(), self.newLab(), self.newLab()
+         self.addIns(ir.IfInstruction(cr, t1, t2))
+         self.addIns(ir.LabelInstruction(t1))
+         self.genCode(code.truestatement)
+         self.addIns(ir.BranchInstruction(te))
+         self.addIns(ir.LabelInstruction(t2))
+         self.genCode(code.falsestatement)
+         self.addIns(ir.LabelInstruction(te))
+      elif type(code) is astnodes.FunctionCall:
+         pass
+      elif type(code) is astnodes.EmptyStatement:
+         pass
+      elif type(code) is astnodes.ReturnStatement:
+         pass
+      else:
+         print('Unknown stmt:', code)
+
+   def genExprCode(self, expr):
+      if type(expr) is astnodes.Binop:
+         ra = self.genExprCode(expr.a)
+         rb = self.genExprCode(expr.b)
+         ops = ['+', '-', '*', '/', 'and', 'or']
+         if expr.op in ops:
+            op = expr.op
+            tmp = self.newTmp()
+            ins = ir.BinaryOperator(tmp, op, ra, rb)
+            self.addIns(ins)
+            return tmp
+         else:
+            print('Unknown binop {0}'.format(expr))
+      elif type(expr) is astnodes.Constant:
          tmp = unique()
-         ins = ir.BinaryOperator(tmp, op, a, b)
-         bb.Instructions.append(ins)
+      elif type(expr) is astnodes.VariableUse:
+         tmp = self.newTmp()
+      elif type(expr) is astnodes.Literal:
+         tmp = self.newTmp()
+         ins = ir.MoveInstruction(tmp, expr.val)
+         self.addIns(ins)
          return tmp
       else:
-         print('Unknown binop {0}'.format(code))
-         bb.Instructions.append(ir.BinaryOperator('unk2', code.op, a, b))
-         return 'unk2'
-   elif type(code) is astnodes.Constant:
-      tmp = unique()
-      bb.Instructions.append(ir.LoadInstruction(tmp, code.value))
-      return tmp
-   elif type(code) is astnodes.VariableUse:
-      tmp = unique()
-      ins = ir.LoadInstruction(tmp, code.target.name)
-      return tmp
-   elif type(code) is astnodes.Literal:
-      tmp = unique()
-      ins = ir.LoadInstruction(tmp, code.val)
-      return tmp
-   else:
-      print('Unknown expr:', code)
-      return 'unk'
+         print('Unknown expr:', code)
 
-def genType(t):
-   return ir.Type()
-
-class CodeGenerator:
-   """ Generates intermediate code """
-   def gencode(self, ast):
-      assert type(ast) is astnodes.Package
-      return genModule(ast)
-
--- a/python/c3/typecheck.py	Sat Mar 23 18:34:41 2013 +0100
+++ b/python/c3/typecheck.py	Fri Mar 29 17:33:17 2013 +0100
@@ -61,12 +61,16 @@
       elif type(sym) is Binop:
          if sym.op in ['+', '-', '*', '/']:
             if equalTypes(sym.a.typ, sym.b.typ):
-               sym.typ = sym.a.typ
+               if equalTypes(sym.a.typ, intType):
+                  sym.typ = sym.a.typ
+               else:
+                  self.diag.error('Can only add integers', sym.loc)
+                  sym.typ = intType
             else:
                # assume void here? TODO: throw exception!
                sym.typ = intType
                self.diag.error('Types unequal', sym.loc)
-         elif sym.op in ['>', '<']:
+         elif sym.op in ['>', '<', '==', '<=', '>=']:
             sym.typ = boolType
             if not equalTypes(sym.a.typ, sym.b.typ):
                self.diag.error('Types unequal', sym.loc)
--- a/python/ir/__init__.py	Sat Mar 23 18:34:41 2013 +0100
+++ b/python/ir/__init__.py	Fri Mar 29 17:33:17 2013 +0100
@@ -1,8 +1,3 @@
-from .module import Module, Function, BasicBlock
-from .value import Value
-from .module import Type, FunctionType
-from .module import i8, i16, i32, void
-from .module import printIr
+from .module import Module
 from .instruction import *
 
-
--- a/python/ir/asmwriter.py	Sat Mar 23 18:34:41 2013 +0100
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,43 +0,0 @@
-
-from . import llvmtype
-from .instruction import BinaryOperator
-#typeNames[VoidType] = 'void'
-
-class AsmWriter:
-   def __init__(self):
-      self.typeNames = {}
-      self.typeNames[llvmtype.typeID.Void] = 'void'
-   def printModule(self, module):
-      if module.Identifier:
-         print('; ModuleID = {0}'.format(module.Identifier))
-      # Print functions:
-      for f in module.Functions:
-         self.printFunction(f)
-   def printFunction(self, f):
-      # TODO: if definition:
-
-      t = self.strType(f.ReturnType.tid)
-      args = '()'
-      print('define {0} {1}{2}'.format(t, f.name, args))
-      print('{')
-      for bb in f.BasicBlocks:
-         print('basic block!')
-         self.printBasicBlock(bb)
-      print('}')
-
-   def strType(self, t):
-      return self.typeNames[t]
-      
-   def printBasicBlock(self, bb):
-      if bb.Name:
-         # print label
-         print('{0}:'.format(bb.Name))
-      for instr in bb.Instructions:
-         self.printInstruction(instr)
-   def printInstruction(self, i):
-      print('Instruction!')
-      if isinstance(i, BinaryOperator):
-         print(i.operation, i.value1.Name, i.value2.Name)
-      else:
-         print(i)
-
--- a/python/ir/bitreader.py	Sat Mar 23 18:34:41 2013 +0100
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,134 +0,0 @@
-from .errors import CompilerException
-from .module import Module
-import struct
-
-def enum(**enums):
-   return type('Enum', (), enums)
-
-BitCodes = enum(END_BLOCK=0, ENTER_SUBBLOCK=1)
-
-class BitstreamReader:
-   def __init__(self, f):
-      self.f = f
-      # Initialize the bitreader:
-      self.bitsInCurrent = 32
-      self.curWord = self.getWord()
-      self.curCodeSize = 2
-   def getWord(self):
-      bts = self.f.read(4)
-      return struct.unpack('<I', bts)[0]
-   def Read(self, numbits):
-      if numbits > 32:
-         raise CompilerException("Cannot read more than 32 bits")
-      if self.bitsInCurrent >= numbits:
-         # numbits inside the current word:
-         R = self.curWord & ((1 << numbits) - 1)
-         self.curWord = self.curWord >> numbits
-         self.bitsInCurrent -= numbits
-         return R 
-      R = self.curWord
-      self.curWord = self.getWord()
-      bitsLeft = numbits - self.bitsInCurrent
-      
-      # Add remaining bits:
-      R |= (self.curWord & (0xFFFFFFFF >> (32 - bitsLeft))) << self.bitsInCurrent
-
-      # Update curword and bits in current:
-      self.curWord = self.curWord >> bitsLeft
-      self.bitsInCurrent = 32 - bitsLeft
-      return R
-   def ReadVBR(self, numbits):
-      """ Read variable bits, checking for the last bit is zero. """
-      piece = self.Read(numbits)
-      if (piece & (1 << (numbits - 1))) == 0:
-         return piece
-      result = 0
-      nextbit = 0
-      while True:
-         mask = (1 << (numbits - 1)) - 1
-         result |= ( piece & mask ) << nextbit
-         if (piece & (1 << (numbits - 1))) == 0:
-            return result
-         nextbit += numbits - 1
-         piece = self.Read(numbits)
-   def ReadCode(self):
-      """ Read the code depending on the current code size """
-      return self.Read(self.curCodeSize)
-   def ReadSubBlockId(self):
-      return self.ReadVBR(8)
-   def EnterSubBlock(self, blockId):
-      pass
-
-BLOCKINFO_BLOCKID = 0
-FIRST_APPLICATION_BLOCKID = 8
-MODULE_BLOCKID = FIRST_APPLICATION_BLOCKID
-
-class BitcodeReader:
-   def __init__(self, f):
-      self.stream = BitstreamReader(f)
-   def parseModule(self):
-      for bitsig in [ord('B'), ord('C')]:
-         if self.stream.Read(8) != bitsig:
-            raise CompilerException('Invalid bitcode signature')
-      for bitsig in [0x0, 0xC, 0xE, 0xD]:
-         if self.stream.Read(4) != bitsig:
-            raise CompilerException('Invalid bitcode signature')
-      while True:
-         code = self.stream.ReadCode()
-         if code != BitCodes.ENTER_SUBBLOCK:
-            raise CompilerException('Invalid record at toplevel')
-         blockId = self.stream.ReadSubBlockId()
-         if blockId == MODULE_BLOCKID:
-            print('module block')
-            pass
-         else:
-            print('Block id:', blockId)
-            raise 
-      return Module()
-
-class BitstreamWriter:
-   def __init__(self, f):
-      self.f = f
-      self.u32 = 0
-      self.curpos = 0
-   def Emit1(self, val):
-      self.Emit(val, 1)
-   def Emit(self, val, numbits):
-      """ Emits value using numbits bits """
-      if numbits == 1:
-         if val != 0:
-            self.u32 |= (0x1 << self.curpos)
-         self.curpos += 1
-         if self.curpos == 32:
-            self.writeWord()
-      elif numbits > 1:
-         for i in range(numbits):
-            if val & (1 << i) != 0:
-               self.Emit1(1)
-            else:
-               self.Emit1(0)
-   def writeWord(self):
-      bts = struct.pack('<I', self.u32)
-      self.f.write(bts)
-      self.u32 = 0
-      self.curpos = 0
-   def flush(self):
-      if self.curpos != 0:
-         self.writeWord()
-
-class BitcodeWriter:
-   def __init__(self):
-      pass
-   def WriteModule(self, module):
-      pass
-   def WriteModuleToFile(self, module, f):
-      s = BitstreamWriter(f)
-      s.Emit(ord('B'), 8)
-      s.Emit(ord('C'), 8)
-      s.Emit(0x0, 4)
-      s.Emit(0xC, 4)
-      s.Emit(0xE, 4)
-      s.Emit(0xD, 4)
-      self.WriteModule(module)
-      s.flush()
-
--- a/python/ir/context.py	Sat Mar 23 18:34:41 2013 +0100
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,12 +0,0 @@
-from .llvmtype import IntegerType, llvmType, typeID, FunctionType
-
-class Context:
-   """ Global context """
-   def __init__(self):
-      self.Int8Type = IntegerType(8)
-      self.Int16Type = IntegerType(16)
-      self.Int32Type = IntegerType(32)
-      self.Int64Type = IntegerType(64)
-      self.VoidType = llvmType(typeID.Void)
-      self.DoubleType = llvmType(typeID.Double)
-
--- a/python/ir/instruction.py	Sat Mar 23 18:34:41 2013 +0100
+++ b/python/ir/instruction.py	Fri Mar 29 17:33:17 2013 +0100
@@ -3,6 +3,13 @@
    """ Base class for all instructions. """
    pass
 
+# Label:
+class LabelInstruction(Instruction):
+   def __init__(self, labid):
+      self.labid = labid
+   def __repr__(self):
+      return '{0}:'.format(self.labid)
+
 # Function calling:
 class CallInstruction(Instruction):
    def __init__(self, callee, arguments):
@@ -16,10 +23,15 @@
    def __repr__(self):
       return 'RET'
 
+class MoveInstruction(Instruction):
+   def __init__(self, name, value):
+      self.name = name
+      self.value = value
+   def __repr__(self):
+      return '{0} = {1}'.format(self.name, self.value)
+
 class BinaryOperator(Instruction):
    def __init__(self, name, operation, value1, value2):
-      assert value1
-      assert value2
       #print('operation is in binops:', operation in BinOps)
       # Check types of the two operands:
       self.name = name
@@ -27,25 +39,34 @@
       self.value2 = value2
       self.operation = operation
    def __repr__(self):
-      return '{0} = {1} {2}, {3}'.format(self.name, self.operation, self.value1, self.value2)
+      return '{0} = {2} {1} {3}'.format(self.name, self.operation, self.value1, self.value2)
 
+# Memory functions:
 class LoadInstruction(Instruction):
    def __init__(self, name, value):
       self.value = value
       self.name = name
    def __repr__(self):
-      return 'load {0} = {1}'.format(self.name, self.value)
+      return '{1} = [{0}]'.format(self.name, self.value)
 
 class StoreInstruction(Instruction):
    def __init__(self, name, value):
       self.name = name
       self.value = value
    def __repr__(self):
-      return 'store {0}'.format(self.name)
+      return '[{0}] = {1}'.format(self.name, self.value)
 
 class BranchInstruction(Instruction):
-   def __init__(self, t1, t2):
-      self.t1 = t1
-      self.t2 = t2
+   def __init__(self, target):
+      self.t1 = target
    def __repr__(self):
       return 'BRANCH {0}'.format(self.t1)
+
+class IfInstruction(Instruction):
+   def __init__(self, cond, lab1, lab2):
+      self.cond = cond
+      self.lab1 = lab1
+      self.lab2 = lab2
+   def __repr__(self):
+      return 'IF {0} THEN {1} ELSE {2}'.format(self.cond, self.lab1, self.lab2)
+
--- a/python/ir/module.py	Sat Mar 23 18:34:41 2013 +0100
+++ b/python/ir/module.py	Fri Mar 29 17:33:17 2013 +0100
@@ -1,88 +1,18 @@
-from .symboltable import SymbolTable
-
-# Types:
-class Type:
-   def __init__(self):
-      pass
-      
-class IntegerType(Type):
-   def __init__(self, bits):
-      super().__init__()
-      self.bits = bits
-
-class VoidType(Type):
-   pass
-
-class FunctionType(Type):
-   def __init__(self, resultType, parameterTypes):
-      super().__init__()
-      assert type(parameterTypes) is list
-      self.resultType = resultType
-      self.parameterTypes = parameterTypes
-
-# Default types:
-i8 = IntegerType(8)
-i16 = IntegerType(16)
-i32 = IntegerType(32)
-void = VoidType()
-
 # IR-Structures:
 
 class Module:
-   """ Main container for a piece of code. Contains globals and functions. """
+   """ Main container for a piece of code. """
    def __init__(self, name):
       self.name = name
-      self.functions = [] # Do functions come out of symbol table?
-      self.globs = [] # TODO: are globals in symbol table?
-      self.symtable = SymbolTable()
-   Globals = property(lambda self: self.globs)
-   Functions = property(lambda self: self.functions)
+      self.instructions = []
    def __repr__(self):
-      return 'IR-mod {0}'.format(self.name)
-
-class Argument:
-   def __init__(self, argtype, name, function):
-      self.t = argtype
-      self.name = name
-      self.function = function
-
-class Function:
-   def __init__(self, name, functiontype):
-      super().__init__()
-      self.name = name
-      self.functiontype = functiontype
-
-      self.basicblocks = []
-      self.arguments = []
-   BasicBlocks = property(lambda self: self.basicblocks)
-   Arguments = property(lambda self: self.arguments)
-   ReturnType = property(lambda self: self.functiontype.returnType)
-   FunctionType = property(lambda self: self.functiontype)
-   def __repr__(self):
-      return 'FUNC {0}'.format(self.name)
-   
-class BasicBlock:
-   """ 
-     A basic block represents a sequence of instructions without
-     jumps and branches.
-   """
-   def __init__(self):
-      super().__init__()
-      self.instructions = []
-      self.label = None
+      return 'IR-module [{0}]'.format(self.name)
    def getInstructions(self):
       return self.instructions
    Instructions = property(getInstructions)
+   def dump(self):
+      print(self)
+      for i in self.Instructions:
+         print(i)
+      print('END')
 
-def printIr(md):
-   print(md)
-   for g in md.Globals:
-      print(g)
-   for f in md.Functions:
-      print(f)
-      for bb in f.BasicBlocks:
-         print('{0}:'.format(bb))
-         for ins in bb.Instructions:
-            print(' {0}'.format(ins))
-      print()
-
--- a/python/ir/symboltable.py	Sat Mar 23 18:34:41 2013 +0100
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,7 +0,0 @@
-
-class SymbolTable:
-   """ Holds a table of symbols for a module or function """
-   def __init__(self):
-      self.symbols = {}
-
-
--- a/python/ir/value.py	Sat Mar 23 18:34:41 2013 +0100
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,25 +0,0 @@
-
-class Value:
-   def __init__(self, vty=None):
-      self.valueType = vty
-      self.name = None
-   def getContext(self):
-      return self.valueType.context
-   def dump(self):
-      print(self)
-   def getName(self):
-      return self.name
-   def setName(self, name):
-      if not self.name and not name:
-         return
-      self.name = name
-   Name = property(getName, setName)
-   def __repr__(self):
-      return 'VALUE {0}'.format(self.name)
-
-class Constant(Value):
-   def __init__(self, value, vty):
-      super().__init__(vty)
-      self.value = value
-      print('new constant value: ', value)
-
--- a/python/testc3.py	Sat Mar 23 18:34:41 2013 +0100
+++ b/python/testc3.py	Fri Mar 29 17:33:17 2013 +0100
@@ -56,6 +56,29 @@
 
 """
 
+testsrc2 = """
+package test2;
+
+function void tst()
+{
+   var int a, b;
+   a = 2 * 33 - 12;
+   b = a * 2 + 13;
+   a = b + a;
+   if (a > b and b == 3)
+   {
+      var int x = a;
+      x = b * 2 - a;
+      a = x*x;
+   }
+   else
+   {
+      a = b + a;
+   }
+}
+
+"""
+
 def c3compile(src, diag):
    # Structures:
    builder = c3.Builder(diag)
@@ -176,6 +199,13 @@
       }
       """
       self.builder.build(snippet)
+   def test2(self):
+      # testsrc2 is valid code:
+      self.diag.clear()
+      ir = self.builder.build(testsrc2)
+      print(self.diag.diags)
+      assert ir
+      ir.dump()
 
 if __name__ == '__main__':
    do()