view python/ppci/ir.py @ 303:be7f60545368

Final fixups
author Windel Bouwman
date Fri, 06 Dec 2013 12:37:48 +0100
parents 6753763d3bec
children fa99f36fabb5
line wrap: on
line source

"""
Intermediate representation (IR) code classes.
"""


def dumpgv(m, outf):
    """ Write module m to outf as a graphviz 'digraph'.

    Each function becomes a box3d node, each block a note node labelled
    with the block and its instructions. Edges are drawn from a block to
    each of its successors and from a function to its entry block.
    Nodes are identified by the id() of the object they represent.
    """
    print('digraph G ', file=outf)
    print('{', file=outf)
    for func in m.Functions:
        func_id = id(func)
        print('{} [label="{}" shape=box3d]'.format(func_id, func), file=outf)
        for block in func.Blocks:
            block_id = id(block)
            # Label is the block header followed by one line per instruction.
            body = '\n'.join([str(ins) for ins in block.Instructions])
            label = str(block) + '\n' + body
            outf.write(
                '{0} [shape=note label="{1}"];\n'.format(block_id, label))
            for succ in block.Successors:
                edge = '"{}" -> "{}"\n'.format(block_id, id(succ))
                print(edge, file=outf)

        entry_edge = '"{}" -> "{}" [label="entry"]\n'.format(
            func_id, id(func.entry))
        outf.write(entry_edge)
    print('}', file=outf)


class Module:
    """ Top level IR container: a named collection of variables and functions. """
    def __init__(self, name):
        self.name = name
        self.variables = []
        self.funcs = []

    def __repr__(self):
        return 'IR-module [{}]'.format(self.name)

    # Construction:
    def addFunc(self, f):
        """ Append function f to this module. """
        self.funcs.append(f)

    addFunction = addFunc

    def addVariable(self, v):
        """ Append variable v to this module. """
        self.variables.append(v)

    # Accessors:
    def getVariables(self):
        """ Return the list of variables (the live list, not a copy). """
        return self.variables

    @property
    def Variables(self):
        return self.variables

    def getFunctions(self):
        """ Return the list of functions (the live list, not a copy). """
        return self.funcs

    @property
    def Functions(self):
        return self.funcs

    def findFunction(self, name):
        """ Return the first function called name, raise KeyError if absent. """
        found = next((f for f in self.funcs if f.name == name), None)
        if found is None:
            raise KeyError(name)
        return found

    getFunction = findFunction

    def dump(self, indent='   '):
        """ Print the module, its variables and its functions to stdout. """
        print(self)
        for var in self.Variables:
            print(indent, var)
        nested = indent + '   '
        for func in self.Functions:
            func.dump(indent=nested)

    # Analysis functions:
    def check(self):
        """ Run the sanity check of every function in the module. """
        for func in self.Functions:
            func.check()


class Function:
    """
      Function definition. Contains an entry block and an epilog block.

      No explicit list of blocks is stored: the Blocks property walks the
      successor edges starting at the entry block.
    """
    def __init__(self, name):
        self.name = name
        self.entry = Block('{}_entry'.format(name))
        self.entry.function = self
        # The epilog block only holds the terminator of the function.
        self.epiloog = Block('{}_epilog'.format(name))
        self.epiloog.function = self
        self.epiloog.addInstruction(Terminator())
        self.return_value = Temp('{}_retval'.format(name))
        self.arguments = []
        self.localvars = []

    def __repr__(self):
        args = ','.join(str(a) for a in self.arguments)
        return 'Function {}({})'.format(self.name, args)

    def addBlock(self, bb):
        """ Mark block bb as belonging to this function.

        The block shows up in Blocks once it is reachable from the entry
        block; there is no separate block list to append to.
        """
        bb.function = self

    def removeBlock(self, bb):
        """ Detach block bb from this function. """
        bb.function = None

    def getBlocks(self):
        """ Return all blocks reachable from the entry block.

        The entry block is always first and the epilog block always last,
        even when the epilog is not reachable.
        """
        bbs = [self.entry]
        worklist = [self.entry]
        while worklist:
            b = worklist.pop()
            for sb in b.Successors:
                if sb not in bbs:
                    bbs.append(sb)
                    worklist.append(sb)
        # Normalise the order: entry up front, epilog at the end.
        bbs.remove(self.entry)
        if self.epiloog in bbs:
            bbs.remove(self.epiloog)
        bbs.insert(0, self.entry)
        bbs.append(self.epiloog)
        return bbs

    def findBasicBlock(self, name):
        """ Return the block called name, raise KeyError if there is none. """
        for bb in self.Blocks:
            if bb.name == name:
                return bb
        raise KeyError(name)

    Blocks = property(getBlocks)

    @property
    def Entry(self):
        return self.entry

    def check(self):
        """ Run the sanity check of every block of this function. """
        for b in self.Blocks:
            b.check()

    def addParameter(self, p):
        """ Append parameter p and store its position in p.num. """
        assert type(p) is Parameter
        p.num = len(self.arguments)
        self.arguments.append(p)

    def addLocal(self, l):
        """ Register local variable l with this function. """
        assert type(l) is LocalVariable
        self.localvars.append(l)

    def dump(self, indent=''):
        """ Print the function, its blocks and their instructions to stdout. """
        print(indent+str(self))
        for bb in self.Blocks:
            print(indent+'   '+str(bb))
            for ins in bb.Instructions:
                print(indent +'   '*2 + str(ins))


class Block:
    """
        Uninterrupted sequence of instructions with a label at the start.
    """
    def __init__(self, name):
        self.name = name
        self.function = None
        self.instructions = []

    # The function this block belongs to (None when detached).
    parent = property(lambda s: s.function)

    def __repr__(self):
        return 'Block {0}'.format(self.name)

    @staticmethod
    def _detach(i):
        """ Unlink instruction i and run its delete() hook if it has one. """
        i.parent = None
        # Not every instruction class defines delete(); treat it as optional
        # so that plain statements can be removed without an AttributeError.
        delete = getattr(i, 'delete', None)
        if callable(delete):
            delete()

    def addInstruction(self, i):
        """ Append instruction i; nothing may follow a LastStatement. """
        # Check before touching i, so a rejected instruction stays unchanged.
        assert not isinstance(self.LastInstruction, LastStatement)
        i.parent = self
        self.instructions.append(i)

    def replaceInstruction(self, i1, i2):
        """ Put i2 at the position of i1. Raises ValueError if i1 is absent. """
        idx = self.instructions.index(i1)
        self._detach(i1)
        i2.parent = self
        self.instructions[idx] = i2

    def removeInstruction(self, i):
        """ Remove instruction i. Raises ValueError if i is not in the block. """
        # Remove first: when i is not present nothing is modified.
        self.instructions.remove(i)
        self._detach(i)

    @property
    def Instructions(self):
        return self.instructions

    @property
    def LastInstruction(self):
        """ Last instruction of the block, or None when the block is empty. """
        if not self.Empty:
            return self.instructions[-1]

    @property
    def Empty(self):
        return len(self.instructions) == 0

    @property
    def FirstInstruction(self):
        """ First instruction of the block; IndexError when the block is empty. """
        return self.instructions[0]

    def getSuccessors(self):
        """ Return the targets of the last instruction ([] for an empty block). """
        if not self.Empty:
            return self.LastInstruction.Targets
        return []
    Successors = property(getSuccessors)

    def getPredecessors(self):
        """ Return the blocks of the owning function that branch to this block. """
        preds = []
        for bb in self.parent.Blocks:
            if self in bb.Successors:
                preds.append(bb)
        return preds
    Predecessors = property(getPredecessors)

    def precedes(self, other):
        raise NotImplementedError()

    def check(self):
        """ Assert that exactly the final instruction is a LastStatement. """
        assert isinstance(self.LastInstruction, LastStatement)
        for i in self.instructions[:-1]:
            assert not isinstance(i, LastStatement)


# Instructions:
class Term:
    """ Pattern leaf: match_tree binds it to whatever subtree it is matched with. """
    def __init__(self, x):
        self.x = x

def match_tree(tree, pattern):
    """ Try to match expression tree against pattern.

    Returns a (matched, bindings) tuple. A Term in the pattern matches
    anything and is mapped to the subtree it matched. Binop nodes match
    when the operation is equal and both operands match. Const nodes match
    on equal value. Any other combination does not match.
    """
    kind = type(pattern)
    if kind is Term:
        return True, {pattern: tree}

    # From here on, tree and pattern must be nodes of exactly the same class.
    if kind is not type(tree):
        return False, {}

    if kind is Binop and tree.operation == pattern.operation:
        left_ok, left_map = match_tree(tree.a, pattern.a)
        right_ok, right_map = match_tree(tree.b, pattern.b)
        # A Term may occur only once in a pattern.
        assert not (left_map.keys() & right_map.keys())
        left_map.update(right_map)
        return left_ok and right_ok, left_map

    if kind is Const and pattern.value == tree.value:
        return True, {}

    return False, {}


class Expression:
    """ Common base class of all expression nodes. """


class Const(Expression):
    """ Expression node holding a constant value. """
    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return 'Const {0}'.format(self.value)


# Function calling:
class Call(Expression):
    """ Call of function f with a list of argument expressions. """
    def __init__(self, f, arguments):
        assert type(f) is Function
        self.f = f
        self.arguments = arguments

    def __repr__(self):
        arg_text = ', '.join(map(str, self.arguments))
        return '{0}({1})'.format(self.f.name, arg_text)


# Data operations
class Binop(Expression):
    """ Binary operation on two operand expressions a and b. """
    # Operators accepted by the constructor:
    ops = ['+', '-', '*', '/', '|', '&', '<<', '>>']

    def __init__(self, value1, operation, value2):
        assert operation in Binop.ops
        self.operation = operation
        self.a = value1
        self.b = value2

    def __repr__(self):
        return '({lhs} {op} {rhs})'.format(
            lhs=self.a, op=self.operation, rhs=self.b)


def Add(a, b):
    """ Shorthand for the Binop a + b """
    return Binop(value1=a, operation='+', value2=b)


def Sub(a, b):
    """ Shorthand for the Binop a - b """
    return Binop(value1=a, operation='-', value2=b)


def Mul(a, b):
    """ Shorthand for the Binop a * b """
    return Binop(value1=a, operation='*', value2=b)


def Div(a, b):
    """ Shorthand for the Binop a / b """
    return Binop(value1=a, operation='/', value2=b)


class Eseq(Expression):
    """ A statement followed by an expression, used as a single expression """
    def __init__(self, stmt, e):
        self.stmt, self.e = stmt, e

    def __repr__(self):
        return '({0}, {1})'.format(self.stmt, self.e)


class Alloc(Expression):
    """ Expression that allocates space on the stack """
    def __init__(self):
        super().__init__()

    def __repr__(self):
        # Carries no operands, so the text is fixed.
        return 'Alloc'


class Variable(Expression):
    """ Named variable; base class of LocalVariable and Parameter. """
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return 'Var {0}'.format(self.name)


class LocalVariable(Variable):
    """ Variable that is registered with a function through addLocal. """
    def __repr__(self):
        return 'Local {0}'.format(self.name)


class Parameter(Variable):
    """ Variable that is registered with a function through addParameter. """
    def __repr__(self):
        return 'Param {0}'.format(self.name)


class Temp(Expression):
    """ Named temporary value, the IR counterpart of a register """
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return 'TMP_{0}'.format(self.name)


class Mem(Expression):
    """ Expression wrapping e; presumably the memory location addressed by e. """
    def __init__(self, e):
        self.e = e

    def __repr__(self):
        return '[{0}]'.format(self.e)


class Statement:
    """ Common base class of every instruction. """


class Move(Statement):
    """ Statement with a destination dst and a source src, shown as 'dst = src'. """
    def __init__(self, dst, src):
        self.dst, self.src = dst, src

    def __repr__(self):
        return '{0} = {1}'.format(self.dst, self.src)


class Exp(Statement):
    """ Statement that wraps a single expression e. """
    def __init__(self, e):
        self.e = e

    def __repr__(self):
        return '{0}'.format(self.e)


# Branching:
class LastStatement(Statement):
    """ Base class of the statements that end a block; subclasses set Targets. """
    def changeTarget(self, old, new):
        """ Replace the first occurrence of old in Targets by new.

        Raises ValueError when old is not one of the targets.
        """
        targets = self.Targets
        targets[targets.index(old)] = new


class Terminator(LastStatement):
    """ Closing instruction of the terminal block; it has no targets. """
    def __init__(self):
        self.Targets = list()

    def __repr__(self):
        return 'Terminator'


class Jump(LastStatement):
    """ Unconditional branch to a single target block. """
    def __init__(self, target):
        self.Targets = [target]

    def setTarget(self, t):
        """ Redirect the jump to block t. """
        self.Targets[0] = t

    @property
    def target(self):
        return self.Targets[0]

    @target.setter
    def target(self, t):
        self.Targets[0] = t

    def __repr__(self):
        return 'JUMP {0}'.format(self.target.name)


class CJump(LastStatement):
    """ Conditional branch: compares a and b and names a block for each outcome. """
    # Comparison operators accepted by the constructor:
    conditions = ['==', '<', '>', '>=', '<=', '!=']

    def __init__(self, a, cond, b, lab_yes, lab_no):
        assert cond in CJump.conditions
        self.a, self.cond, self.b = a, cond, b
        # Targets[0] is taken when the condition holds, Targets[1] otherwise.
        self.Targets = [lab_yes, lab_no]

    @property
    def lab_yes(self):
        return self.Targets[0]

    @property
    def lab_no(self):
        return self.Targets[1]

    def __repr__(self):
        template = 'IF {} {} {} THEN {} ELSE {}'
        return template.format(
            self.a, self.cond, self.b, self.lab_yes, self.lab_no)


# Constructing IR:

class NamedClassGenerator:
    """ Creates instances of cls with names made of a prefix and a running number. """
    def __init__(self, prefix, cls):
        self.prefix = prefix
        self.cls = cls
        self.nums = self._count_up()

    @staticmethod
    def _count_up():
        """ Yield 0, 1, 2, ... without end. """
        n = 0
        while True:
            yield n
            n += 1

    def gen(self, prefix=None):
        """ Return cls(name) for the next number.

        An empty or missing prefix selects the default prefix. The counter
        is shared between all prefixes.
        """
        chosen = prefix if prefix else self.prefix
        name = '{0}{1}'.format(chosen, next(self.nums))
        return self.cls(name)


class Builder:
    """ Base class for ir code generators.

    Keeps track of the module, function, block and source location that
    newly emitted instructions are attached to.
    """
    def __init__(self):
        self.prepare()

    def prepare(self):
        """ Reset the builder: fresh name generators, nothing selected. """
        temp_factory = NamedClassGenerator('reg', Temp)
        block_factory = NamedClassGenerator('block', Block)
        self.newTemp = temp_factory.gen
        self.newBlock2 = block_factory.gen
        self.bb = None
        self.m = None
        self.fn = None
        self.loc = None

    # Helpers:
    def setModule(self, m):
        """ Select the module that new functions are added to. """
        self.m = m

    def newFunction(self, name):
        """ Create a function called name and add it to the current module. """
        func = Function(name)
        self.m.addFunc(func)
        return func

    def newBlock(self):
        """ Create a freshly named block owned by the current function. """
        assert self.fn
        block = self.newBlock2()
        block.function = self.fn
        return block

    def setFunction(self, f):
        """ Select function f; its entry block becomes the current block. """
        self.fn = f
        if f:
            self.bb = f.entry
        else:
            self.bb = None

    def setBlock(self, b):
        """ Select the block that emit appends to. """
        self.bb = b

    def setLoc(self, l):
        """ Set the location that is attached to emitted instructions. """
        self.loc = l

    def emit(self, i):
        """ Tag statement i with the current location and append it to the
        current block. Raises Exception when no block is selected.
        """
        assert isinstance(i, Statement)
        i.debugLoc = self.loc
        current = self.bb
        if not current:
            raise Exception('No basic block')
        current.addInstruction(i)