diff python/pyburg.py @ 322:44f336460c2a

Half of use of burg spec for arm
author Windel Bouwman
date Mon, 27 Jan 2014 19:58:07 +0100
parents 8c569fbe60e4
children e9fe6988497c
line wrap: on
line diff
--- a/python/pyburg.py	Sun Jan 19 18:48:45 2014 +0100
+++ b/python/pyburg.py	Mon Jan 27 19:58:07 2014 +0100
@@ -1,39 +1,52 @@
 #!/usr/bin/python
 
 """
-Bottom up rewrite generator.
+Bottom up rewrite generator
+---------------------------
 
 This script takes as input a description of patterns and outputs a
 matcher class that can match trees given the patterns.
 
-    Patterns are specified as follows:
+Patterns are specified as follows::
+
      reg -> ADDI32(reg, reg) 2 (. add NT0 NT1 .)
      reg -> MULI32(reg, reg) 3 (. .)
-    or a multiply add:
-     reg -> ADDI32(MULI32(reg, reg), reg) 4 (. muladd $1, $2, $3 .)
-    The general specification pattern is:
-     [result] -> [tree] [cost] [template code]
+
+or a multiply add::
+
+    reg -> ADDI32(MULI32(reg, reg), reg) 4 (. muladd $1, $2, $3 .)
+
+The general specification pattern is::
+
+    [result] -> [tree] [cost] [template code]
 
 Trees
 -----
 
 A tree is described using parenthesis notation. For example a node X with
 three child nodes is described as:
+
      X(a, b, b)
+
 Trees can be nested:
+
      X(Y(a, a), a)
+
 The 'a' in the example above indicates an open connection to a next tree
 pattern.
 
 
 In the example above 'reg' is a non-terminal. ADDI32 is a terminal. non-terminals
 cannot have child nodes. A special case occurs in this case:
-reg -> rc
+
+    reg -> rc
+
 where 'rc' is a non-terminal. This is an example of a chain rule. Chain rules
 can be used to allow several variants of non-terminals.
 
 The generated matcher uses dynamic programming to find the best match of the
 tree. This strategy consists of two steps:
+
   - label: During this phase the given tree is traversed in a bottom up way.
     each node is labelled with a possible matching rule and the corresponding cost.
   - select: In this step, the tree is traversed again, selecting at each point
@@ -43,6 +56,8 @@
 
 import sys
 import os
+import io
+import types
 import argparse
 from ppci import Token
 from pyyacc import ParserException, EOF
@@ -68,16 +83,24 @@
             ]
 
         lines = txt.split('\n')
+        header_lines = []
 
         def tokenize():
+            section = 0
             for line in lines:
                 line = line.strip()
                 if not line:
                     continue  # Skip empty lines
                 elif line == '%%':
+                    section += 1
+                    if section == 1:
+                        yield Token('header', header_lines)
                     yield Token('%%', '%%')
                 else:
-                    yield from baselex.tokenize(tok_spec, line)
+                    if section == 0:
+                        header_lines.append(line)
+                    else:
+                        yield from baselex.tokenize(tok_spec, line)
             yield Token(EOF, EOF)
         self.tokens = tokenize()
         self.token = self.tokens.__next__()
@@ -188,6 +211,8 @@
 
         self.print('#!/usr/bin/python')
         self.print('from tree import Tree, BaseMatcher, State')
+        for header in self.system.header_lines:
+            self.print(header)
         self.print()
         self.print('class Matcher(BaseMatcher):')
         self.print('    def __init__(self):')
@@ -209,8 +234,13 @@
                 args = ', ' + ', '.join('nt{}'.format(x) for x in range(rule.num_nts))
             else:
                 args = ''
-            self.print('    def P{}(self{}):'.format(rule.nr, args))
-            self.print('        {}'.format(rule.template))
+            self.print('    def P{}(self, tree{}):'.format(rule.nr, args))
+            template = rule.template
+            template = template.replace('$$', 'tree')
+            for i in range(rule.num_nts):
+                template = template.replace('${}'.format(i+1), 'nt{}'.format(i))
+            for t in template.split(';'):
+                self.print('        {}'.format(t.strip()))
         self.emit_state()
         self.print('    def gen(self, tree):')
         self.print('        self.burm_label(tree)')
@@ -220,18 +250,18 @@
 
     def emit_record(self, rule, state_var):
         # TODO: check for rules fullfilled (by not using 999999)
-        self.print('        nts = self.nts({})'.format(rule.nr))
-        self.print('        kids = self.kids(tree, {})'.format(rule.nr))
-        self.print('        c = sum(x.state.get_cost(y) for x, y in zip(kids, nts)) + {}'.format(rule.cost))
-        self.print('        tree.state.set_cost("{}", c, {})'.format(rule.non_term, rule.nr))
+        self.print('            nts = self.nts({})'.format(rule.nr))
+        self.print('            kids = self.kids(tree, {})'.format(rule.nr))
+        self.print('            c = sum(x.state.get_cost(y) for x, y in zip(kids, nts)) + {}'.format(rule.cost))
+        self.print('            tree.state.set_cost("{}", c, {})'.format(rule.non_term, rule.nr))
         for cr in self.system.symbols[rule.non_term].chain_rules:
-            self.print('        # Chain rule: {}'.format(cr))
-            self.print('        tree.state.set_cost("{}", c + {}, {})'.format(cr.non_term, cr.cost, cr.nr))
+            self.print('            # Chain rule: {}'.format(cr))
+            self.print('            tree.state.set_cost("{}", c + {}, {})'.format(cr.non_term, cr.cost, cr.nr))
 
     def emit_state(self):
         """ Emit a function that assigns a new state to a node """
         self.print('    def burm_state(self, tree):')
-        self.print('     tree.state = State()')
+        self.print('        tree.state = State()')
         for term in self.system.terminals:
             self.emitcase(term)
         self.print()
@@ -240,7 +270,7 @@
         rules = [rule for rule in self.system.rules if rule.tree.name == term]
         for rule in rules:
             condition = self.emittest(rule.tree, 'tree')
-            self.print('     if {}:'.format(condition))
+            self.print('        if {}:'.format(condition))
             self.emit_record(rule, 'state')
 
     def compute_kids(self, t, root_name):
@@ -278,6 +308,16 @@
         default=sys.stdout)
     return parser
 
+def load_as_module(filename):
+    """ Load a parser spec file, generate LR tables and create module """
+    ob = io.StringIO()
+    args = argparse.Namespace(source=open(filename), output=ob)
+    main(args)
+
+    matcher_mod = types.ModuleType('generated_matcher')
+    exec(ob.getvalue(), matcher_mod.__dict__)
+    return matcher_mod
+
 
 def main(args):
     src = args.source.read()