changeset 1219:9fac28d80fb7

plugin_JB - removed FILT and BUFFER_REPEAT, added Registers
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 22 Sep 2010 13:31:31 -0400
parents 5d1b5906151c
children 35fb6e9713d2
files doc/v2_planning/arch_src/plugin_JB.py doc/v2_planning/arch_src/plugin_JB_main.py
diffstat 2 files changed, 139 insertions(+), 87 deletions(-) [+]
line wrap: on
line diff
--- a/doc/v2_planning/arch_src/plugin_JB.py	Wed Sep 22 13:00:39 2010 -0400
+++ b/doc/v2_planning/arch_src/plugin_JB.py	Wed Sep 22 13:31:31 2010 -0400
@@ -10,28 +10,29 @@
     dataset = Dataset(numpy.random.RandomState(123).randn(13,1))
     pca = PCA_Analysis()
     pca_batchsize=1000
+    
+    reg = Registers()
 
     # define the control-flow of the algorithm
     train_pca = SEQ([
-        BUFFER_REPEAT(pca_batchsize, CALL(dataset.next)), 
-        FILT(pca.analyze)])
+        REPEAT(pca_batchsize, CALL(dataset.next, _set=reg('x'))), 
+        CALL(pca.analyze, reg('x'))])
 
     # run the program
     train_pca.run()
 
-The CALL, SEQ, FILT, and BUFFER_REPEAT are control-flow elements. The control-flow elements I
+The CALL, SEQ, and REPEAT are control-flow elements. The control-flow elements I
 defined so far are:
 
 - CALL - a basic statement, just calls a python function
-- FILT - like call, but passes the return value of the last CALL or FILT to the python function
 - SEQ - a sequence of elements to run in order
 - REPEAT - do something N times (and return None or maybe the last CALL?)
-- BUFFER_REPEAT - do something N times and accumulate the return value from each iter
 - LOOP - do something an infinite number of times
 - CHOOSE - like a switch statement (should rename to SWITCH)
 - WEAVE - interleave execution of multiple control-flow elements
 - POPEN - launch a process and return its status when it's complete
 - PRINT - a shortcut for CALL(print_obj)
+- SPAWN - run a program fragment asynchronously in another process
 
 
 We don't have many requirements per-se for the architecture, but I think this design respects
@@ -95,14 +96,14 @@
     """
 
     # subclasses should override these methods:
-    def start(self, arg):
+    def start(self):
         pass
     def step(self):
         pass
 
     # subclasses should typically not override these:
-    def run(self, arg=None, n_steps=float('inf')):
-        self.start(arg)
+    def run(self, n_steps=float('inf')):
+        self.start()
         i = 0
         r = self.step()
         while r is INCOMPLETE:
@@ -162,20 +163,15 @@
         self.fn = fn
         self.args = args
         self.kwargs=kwargs
-        self.use_start_arg = kwargs.pop('use_start_arg', False)
-    def start(self, arg):
-        self.start_arg = arg
+    def start(self):
         self.finished = False
         return self
     def step(self):
         assert not self.finished
         self.finished = True
-        if self.use_start_arg:
-            if self.args:
-                raise TypeError('cant get positional args both ways')
-            return self.fn(self.start_arg, **self.kwargs)
-        else:
-            return self.fn(*self.args, **self.kwargs)
+        fn_rval = self.fn(*self.lookup_args(), **self.lookup_kwargs())
+        if '_set' in self.kwargs:
+            self.kwargs['_set'].set(fn_rval)
     def __getstate__(self):
         rval = dict(self.__dict__)
         if type(self.fn) is type(self.step): #instancemethod
@@ -187,12 +183,24 @@
             dct['fn'] = type(self.step)(*dct.pop('i fn'))
         self.__dict__.update(dct)
 
-def FILT(fn, **kwargs):
-    """
-    Return a CALL object that uses the return value from the previous CALL as the first and
-    only positional argument.
-    """
-    return CALL(fn, use_start_arg=True, **kwargs)
+    def lookup_args(self):  # positional args for fn, with every Register replaced by its current value
+        rval = []
+        for a in self.args:
+            if isinstance(a, Register):
+                rval.append(a.get())  # dereferenced when step() runs, not when the CALL was built
+            else:
+                rval.append(a)  # plain values are passed through unchanged
+        return rval
+    def lookup_kwargs(self):  # keyword args for fn, with every Register value replaced by its current value
+        rval = {}
+        for k,v in self.kwargs.iteritems():
+            if k == '_set':  # '_set' names the output Register used by step(); it is not forwarded to fn
+                continue
+            if isinstance(v, Register):
+                rval[k] = v.get()  # dereferenced when step() runs
+            else:
+                rval[k] = v
+        return rval
 
 def CHOOSE(which, options):
     """
@@ -200,46 +208,60 @@
     """
     raise NotImplementedError()
 
-def LOOP(elements):
+def LOOP(element):
     #TODO: implement a true infinite loop
-    try:
-        iter(elements)
-        return REPEAT(sys.maxint, elements)
-    except TypeError:
-        return REPEAT(sys.maxint, [elements])
+    return REPEAT(sys.maxint, element)
 
 class REPEAT(ELEMENT):
-    def __init__(self, N, elements, pass_rvals=False):
+    def __init__(self, N, element, counter=None):
         self.N = N
-        self.elements = elements
-        self.pass_rvals = pass_rvals
+        if not isinstance(element, ELEMENT):
+            raise TypeError(element)
+        self.element = element
+        self.counter = counter
 
     #TODO: check for N being callable
-    def start(self, arg):
+    def start(self):
         self.n = 0   #loop iteration
-        self.idx = 0 #element idx
         self.finished = False
-        self.elements[0].start(arg)
+        self.element.start()
+        if self.counter:
+            self.counter.set(0)
+
     def step(self):
         assert not self.finished
-        r = self.elements[self.idx].step()
+        r = self.element.step()
         if r is INCOMPLETE:
             return INCOMPLETE
-        self.idx += 1
-        if self.idx < len(self.elements):
-            self.elements[self.idx].start(r)
-            return INCOMPLETE
         self.n += 1
+        if self.counter:
+            self.counter.set(self.n)
         if self.n < self.N:
-            self.idx = 0
-            self.elements[self.idx].start(r)
+            self.element.start()
             return INCOMPLETE
         else:
             self.finished = True
             return r
 
-def SEQ(elements):
-    return REPEAT(1, elements)
+class SEQ(ELEMENT):  # runs its child elements one after another, in list order
+    def __init__(self, elements):
+        self.elements = list(elements)  # private copy; any iterable of elements is accepted
+    def start(self):
+        self.pos = 0  # set unconditionally so that step() also works on an empty SEQ
+        if self.elements:
+            self.elements[0].start()
+        self.finished = False
+    def step(self):
+        if self.pos == len(self.elements):  # past the last child: nothing left to run
+            self.finished=True
+            return
+        r = self.elements[self.pos].step()
+        if r is INCOMPLETE:  # the current child needs more steps
+            return r
+        self.pos += 1
+        if self.pos < len(self.elements):
+            self.elements[self.pos].start()  # each child is started just before its first step
+        return INCOMPLETE  # completion is only reported by the following step() call
 
 class WEAVE(ELEMENT):
     """
@@ -253,9 +275,9 @@
             self.n_required = len(elements)
         else:
             self.n_required = n_required
-    def start(self, arg):
+    def start(self):
         for el in self.elements:
-            el.start(arg)
+            el.start()
         self.elem_finished = [0] * len(self.elements)
         self.idx = 0
         self.finished= False 
@@ -289,7 +311,7 @@
 class POPEN(ELEMENT):
     def __init__(self, args):
         self.args = args
-    def start(self, arg):
+    def start(self):
         self.p = subprocess.Popen(self.args)
     def step(self):
         r = self.p.poll() 
@@ -305,7 +327,7 @@
     def __init__(self, data, prog):
         self.data = data
         self.prog = prog
-    def start(self, arg):
+    def start(self):
         # pickle the (data, prog) pair
         s = cPickle.dumps((self.data, self.prog))
 
@@ -353,6 +375,17 @@
         return SPAWN.SUCCESS
         #os.close(wpipe)
 
+class Register(object):  # handle to a single key of a backing mapping
+    def __init__(self, registers, key):
+        self.registers = registers  # the backing mapping; indexed with [] by set() and get()
+        self.key = key
+    def set(self, val):
+        self.registers[self.key] = val
+    def get(self):
+        return self.registers[self.key]  # no default: a failed lookup propagates from the mapping
+class Registers(dict):  # dict of register values; calling it with a key returns a Register handle
+    def __call__(self, key):
+        return Register(self, key)  # a new handle on every call; the key is not checked for existence
 
 def print_obj(obj):
     print obj
@@ -364,3 +397,6 @@
 def importable_fn(d):
     d['new key'] = len(d)
 
+
+if __name__  == '__main__':
+    print 'this is the library file, run "python plugin_JB_main.py"'
--- a/doc/v2_planning/arch_src/plugin_JB_main.py	Wed Sep 22 13:00:39 2010 -0400
+++ b/doc/v2_planning/arch_src/plugin_JB_main.py	Wed Sep 22 13:31:31 2010 -0400
@@ -10,10 +10,10 @@
     def __init__(self, data):
         self.pos = 0
         self.data = data
-    def next(self):
-        rval = self.data[self.pos]
-        self.pos += 1
-        if self.pos == len(self.data):
+    def next(self, n=1):
+        rval = self.data[self.pos:self.pos+n]
+        self.pos += n
+        if self.pos >= len(self.data):
             self.pos = 0
         return rval
     def seek(self, pos):
@@ -28,27 +28,29 @@
     def next_fold(self):
         self.k += 1
         self.data.seek(0) # restart the stream
-    def next(self):
+    def next(self, n=1):
         #TODO: skip the examples that are ommitted in this split
-        return self.data.next()
+        return self.data.next(n)
     def init_test(self):
         pass
-    def next_test(self):
-        return self.data.next()
+    def next_test(self, n=1):
+        return self.data.next(n)
     def test_size(self):
         return 5
     def store_scores(self, scores):
         self.scores[self.k] = scores
 
-    def prog(self, clear, train, test):
-        return REPEAT(self.K, [
+    def prog(self, clear, train, test, test_data_reg, test_counter_reg, test_scores_reg):
+        return REPEAT(self.K, SEQ([
             CALL(self.next_fold),
             clear,
             train,
             CALL(self.init_test),
-            BUFFER_REPEAT(self.test_size(),
-                SEQ([ CALL(self.next_test), test])),
-            FILT(self.store_scores) ])
+            REPEAT(self.test_size(), SEQ([
+                CALL(self.next_test, _set=test_data_reg), 
+                test]),
+                counter=test_counter_reg),
+            CALL(self.store_scores, test_scores_reg)]))
 
 class PCA_Analysis(object):
     def __init__(self):
@@ -93,8 +95,8 @@
         return l[0]
 
     print WEAVE(1, [
-        BUFFER_REPEAT(3,CALL(f,1)),
-        BUFFER_REPEAT(5,CALL(f,1)),
+        REPEAT(3,CALL(f,1)),
+        REPEAT(5,CALL(f,1)),
         ]).run()
 
 def main_weave_popen():
@@ -104,9 +106,9 @@
     p = WEAVE(2,[
         SEQ([POPEN(['sleep', '5']), PRINT('done 1')]),
         SEQ([POPEN(['sleep', '10']), PRINT('done 2')]),
-        LOOP([ 
+        LOOP(SEQ([ 
             CALL(print_obj, 'polling...'),
-            CALL(time.sleep, 1)])])
+            CALL(time.sleep, 1)]))])
     # The LOOP would forever if the WEAVE were not configured to stop after 2 of its elements
     # complete.
 
@@ -120,15 +122,15 @@
     data1 = {0:"blah data1"}
     data2 = {1:"foo data2"}
     p = WEAVE(2,[
-        SPAWN(data1, REPEAT(3, [
+        SPAWN(data1, REPEAT(3, SEQ([
             CALL(importable_fn, data1), 
-            PRINT("hello from 1")])),
-        SPAWN(data2, REPEAT(1, [
+            PRINT("hello from 1")]))),
+        SPAWN(data2, REPEAT(1, SEQ([
             CALL(importable_fn, data2), 
-            PRINT("hello from 2")])),
-        LOOP([ 
+            PRINT("hello from 2")]))),
+        LOOP(SEQ([ 
             CALL(print_obj, 'polling...'),
-            CALL(time.sleep, 0.5)])])
+            CALL(time.sleep, 0.5)]))])
     print 'BEFORE'
     print data1
     print data2
@@ -148,6 +150,7 @@
     layer1 = Layer(w=4)
     layer2 = Layer(w=3)
     kf = KFold(dataset, K=10)
+    reg = Registers()
 
     pca_batchsize=1000
     cd_batchsize = 5
@@ -157,19 +160,19 @@
     # create algorithm
 
     train_pca = SEQ([
-        BUFFER_REPEAT(pca_batchsize, CALL(kf.next)), 
-        FILT(pca.analyze)])
+        CALL(kf.next, pca_batchsize, _set=reg('x')), 
+        CALL(pca.analyze, reg('x'))])
 
-    train_layer1 = REPEAT(n_cd_updates_layer1, [
-        BUFFER_REPEAT(cd_batchsize, CALL(kf.next)),
-        FILT(pca.filt), 
-        FILT(cd1_update, layer=layer1, lr=.01)])
+    train_layer1 = REPEAT(n_cd_updates_layer1, SEQ([
+        CALL(kf.next, cd_batchsize, _set=reg('x')),
+        CALL(pca.filt, reg('x'), _set=reg('x')), 
+        CALL(cd1_update, reg('x'), layer=layer1, lr=.01)]))
 
-    train_layer2 = REPEAT(n_cd_updates_layer2, [
-        BUFFER_REPEAT(cd_batchsize, CALL(kf.next)),
-        FILT(pca.filt), 
-        FILT(layer1.filt),
-        FILT(cd1_update, layer=layer2, lr=.01)])
+    train_layer2 = REPEAT(n_cd_updates_layer2, SEQ([
+        CALL(kf.next, cd_batchsize, _set=reg('x')),
+        CALL(pca.filt, reg('x'), _set=reg('x')), 
+        CALL(layer1.filt, reg('x'), _set=reg('x')),
+        CALL(cd1_update, reg('x'), layer=layer2, lr=.01)]))
 
     kfold_prog = kf.prog(
             clear = SEQ([   # FRAGMENT 1: this bit is the reset/clear stage
@@ -181,14 +184,17 @@
                 train_pca,
                 WEAVE(1, [    # Silly example of how to do debugging / loggin with WEAVE
                     train_layer1, 
-                    LOOP(CALL(print_obj_attr, layer1, 'w'))]),
+                    LOOP(PRINT(reg('x')))]),
                 train_layer2,
                 ]),
             test=SEQ([
-                FILT(pca.filt),       # may want to allow this SEQ to be 
-                FILT(layer1.filt),    # optimized into a shorter one that
-                FILT(layer2.filt),    # compiles these calls together with 
-                FILT(numpy.mean)]))   # Theano
+                CALL(pca.filt, reg('testx'), _set=reg('x')),  
+                CALL(layer1.filt, reg('x'), _set=reg('x')),
+                CALL(layer2.filt, reg('x'), _set=reg('x')),
+                CALL(numpy.mean, reg('x'), _set=reg('score'))]),
+            test_data_reg=reg('testx'),
+            test_counter_reg=reg('i'),
+            test_scores_reg=reg('score'))
 
     pkg1 = dict(prog=kfold_prog, kf=kf)
     pkg2 = copy.deepcopy(pkg1)       # programs can be copied
@@ -206,4 +212,14 @@
 
 
 if __name__ == '__main__':
+    try:
+        sys.argv[1]  # presence check only; the value is eval()'d right after this block, so pass trusted input only
+    except IndexError:  # only a missing argument should trigger the usage text (not Ctrl-C etc.)
+        print """You have to tell which main function to use, try:
+    - python plugin_JB_main.py 'main_kfold_dbn()'
+    - python plugin_JB_main.py 'main_weave()'
+    - python plugin_JB_main.py 'main_weave_popen()'
+    - python plugin_JB_main.py 'main_spawn()'
+        """
+        sys.exit(1)
     sys.exit(eval(sys.argv[1]))