275
|
1
|
|
2 import ir
|
|
3 from itertools import chain
|
|
4
|
|
def make(function, frame):
    """
    Canonicalize the IR code of *function* in place.

    Steps performed:
    - every argument of *function* gets a fresh temporary; the mapping is
      recorded in ``frame.parMap`` and a Move from
      ``frame.argLoc(arg.num)`` into that temporary is prepended to the
      instructions of the entry block;
    - every statement of every block in ``function.Blocks`` is rewritten
      with rewriteStmt (which turns calls into Eseq expressions);
    - every block is then linearized, pulling the statements out of the
      Eseq expressions.

    Note: this modifies the IR tree of *function*; nothing is returned.
    """
    def _loadArgument(arg):
        # Allocate a register for the argument and remember the mapping so
        # that rewriteExp can replace Parameter nodes with it.
        reg = newTemp()
        frame.parMap[arg] = reg
        return ir.Move(reg, frame.argLoc(arg.num))

    prologue = [_loadArgument(arg) for arg in function.arguments]
    function.entry.instructions = prologue + function.entry.instructions

    for block in function.Blocks:
        for instruction in block.instructions:
            rewriteStmt(instruction, frame)
        linearize(block)
|
|
24
|
|
25 # Visit all nodes with some function:
|
|
26 # TODO: rewrite into visitor.
|
|
27
|
|
# rewriteStmt / rewriteExp rewrite call expressions into Eseq expressions.
|
|
29
|
|
def rewriteStmt(stmt, frame):
    """
    Rewrite the expressions inside statement *stmt* in place.

    Jump and other Terminator statements are left untouched. CJump operands,
    Move source/destination (source first) and the expression of an Exp
    statement are replaced by the result of rewriteExp.

    Raises NotImplementedError for any other statement type.
    """
    if isinstance(stmt, ir.Jump):
        return
    if isinstance(stmt, ir.CJump):
        stmt.a = rewriteExp(stmt.a, frame)
        stmt.b = rewriteExp(stmt.b, frame)
        return
    if isinstance(stmt, ir.Move):
        # Source before destination: rewriteExp may allocate frame slots,
        # so keep the original order.
        stmt.src = rewriteExp(stmt.src, frame)
        stmt.dst = rewriteExp(stmt.dst, frame)
        return
    if isinstance(stmt, ir.Terminator):
        return
    if isinstance(stmt, ir.Exp):
        stmt.e = rewriteExp(stmt.e, frame)
        return
    raise NotImplementedError('STMT NI: {}'.format(stmt))
|
|
45
|
|
# Factory for the fresh temporaries introduced by this pass; each call to
# newTemp() yields a new ir.Temp named with the 'canon_reg' prefix.
_canonRegGenerator = ir.NamedClassGenerator('canon_reg', ir.Temp)
newTemp = _canonRegGenerator.gen
|
|
47
|
|
def rewriteExp(exp, frame):
    """
    Rewrite expression *exp* and return the expression to use in its place.

    - Binop and Mem nodes get their operands rewritten in place and are
      returned themselves.
    - Const and Temp nodes are returned unchanged.
    - Parameter nodes are replaced by the temporary in ``frame.parMap``.
    - LocalVariable nodes are replaced by ``ir.Add(frame.fp, ir.Const(off))``
      where ``off`` comes from ``frame.allocVar``.
    - Call nodes get their arguments rewritten and are then wrapped as
      ``Eseq(Move(t, call), t)`` with a fresh temporary ``t``.

    Raises NotImplementedError for any other node type.
    """
    if isinstance(exp, ir.Binop):
        exp.a = rewriteExp(exp.a, frame)
        exp.b = rewriteExp(exp.b, frame)
        return exp
    if isinstance(exp, (ir.Const, ir.Temp)):
        return exp
    if isinstance(exp, ir.Parameter):
        return frame.parMap[exp]
    if isinstance(exp, ir.LocalVariable):
        varOffset = frame.allocVar(exp)
        return ir.Add(frame.fp, ir.Const(varOffset))
    if isinstance(exp, ir.Mem):
        exp.e = rewriteExp(exp.e, frame)
        return exp
    if isinstance(exp, ir.Call):
        exp.arguments = [rewriteExp(arg, frame) for arg in exp.arguments]
        # Capture the call result in a fresh register; linearize later
        # hoists the Move out of the enclosing expression.
        result = newTemp()
        return ir.Eseq(ir.Move(result, exp), result)
    raise NotImplementedError('NI: {}'.format(exp))
|
|
72
|
|
# The flatten functions pull the statements out of Eseq expressions and return them as a list of statements to execute first.
|
|
74
|
|
def _commutes(exp):
    """
    Return True if *exp* yields the same value whether it is evaluated
    before or after statements hoisted out of a sibling expression.

    In this pass the hoisted statements come from the Eseq nodes built by
    rewriteExp: a Move of a call result into a fresh temporary. Constants
    and temporaries cannot be affected by that. Anything else (e.g. a Mem
    read) might observe the call's side effects and is conservatively
    treated as non-commuting.
    """
    return isinstance(exp, (ir.Const, ir.Temp))


def _protect(exp, stmts, hoistedAfter):
    """
    Keep the value of *exp* from being computed after later hoisted code.

    *stmts* are the statements that must run before *exp*. *hoistedAfter*
    tells whether a sibling evaluated after *exp* hoists statements that
    will end up in front of *exp*.

    Returns ``(exp, stmts)`` unchanged when nothing is hoisted after it or
    *exp* commutes. Otherwise *exp* is first saved into a fresh temporary
    (a Move appended to *stmts*), and that temporary is returned instead.
    """
    if not hoistedAfter or _commutes(exp):
        return exp, stmts
    saved = newTemp()
    return saved, stmts + [ir.Move(saved, exp)]


def flattenExp(exp):
    """
    Pull the statements out of Eseq nodes inside expression *exp*.

    Returns a tuple ``(exp, stmts)``:
    - ``exp`` is the Eseq-free expression;
    - ``stmts`` is the list of statements that must run before it, in
      evaluation order.

    Left-to-right evaluation is preserved. If a later operand hoists
    statements, an earlier non-commuting operand is first saved into a
    fresh temporary.

    Raises NotImplementedError for unsupported node types.
    """
    if isinstance(exp, ir.Binop):
        a, sa = flattenExp(exp.a)
        b, sb = flattenExp(exp.b)
        # Statements hoisted out of b will run before a is evaluated; save a
        # first if those statements could change its value.
        exp.a, sa = _protect(a, sa, bool(sb))
        exp.b = b
        return exp, sa + sb
    elif isinstance(exp, ir.Temp):
        return exp, []
    elif isinstance(exp, ir.Const):
        return exp, []
    elif isinstance(exp, ir.Mem):
        exp.e, s = flattenExp(exp.e)
        return exp, s
    elif isinstance(exp, ir.Eseq):
        s = flattenStmt(exp.stmt)
        exp.e, se = flattenExp(exp.e)
        return exp.e, s + se
    elif isinstance(exp, ir.Call):
        flat = [flattenExp(arg) for arg in exp.arguments]
        # Walk right to left so each argument knows whether any later
        # argument hoists statements in front of it.
        args = []
        argStmts = []
        hoistedAfter = False
        for arg, stmts in reversed(flat):
            arg, stmts = _protect(arg, stmts, hoistedAfter)
            args.append(arg)
            argStmts.append(stmts)
            hoistedAfter = hoistedAfter or bool(stmts)
        args.reverse()
        argStmts.reverse()
        exp.arguments = args
        return exp, list(chain.from_iterable(argStmts))
    else:
        raise NotImplementedError('NI: {}'.format(exp))


def flattenStmt(stmt):
    """
    Flatten statement *stmt* into a list of Eseq-free statements.

    Statements hoisted out of the operands come first, and *stmt* itself
    comes last.

    Raises NotImplementedError for unsupported statement types.
    """
    if isinstance(stmt, ir.Jump):
        return [stmt]
    elif isinstance(stmt, ir.CJump):
        a, sa = flattenExp(stmt.a)
        b, sb = flattenExp(stmt.b)
        # Same ordering concern as for Binop: keep a ahead of b's hoisted code.
        stmt.a, sa = _protect(a, sa, bool(sb))
        stmt.b = b
        return sa + sb + [stmt]
    elif isinstance(stmt, ir.Move):
        # NOTE(review): the destination address is still evaluated after the
        # source's hoisted statements; confirm this matches the IR's
        # intended Move semantics.
        stmt.dst, sd = flattenExp(stmt.dst)
        stmt.src, ss = flattenExp(stmt.src)
        return sd + ss + [stmt]
    elif isinstance(stmt, ir.Terminator):
        return [stmt]
    elif isinstance(stmt, ir.Exp):
        stmt.e, se = flattenExp(stmt.e)
        return se + [stmt]
    else:
        raise NotImplementedError('STMT NI: {}'.format(stmt))
|
|
120
|
|
121
|
|
def linearize(block):
    """
    Flatten the instruction list of *block* in place.

    Each instruction is passed through flattenStmt. The resulting lists are
    concatenated in order, so statements pulled out of Eseq expressions come
    right before the statement that used them.
    """
    flattened = []
    for stmt in block.instructions:
        flattened.extend(flattenStmt(stmt))
    block.instructions = flattened
|
|
128
|