diff doc/v2_planning/arch_src/plugin_JB_main.py @ 1219:9fac28d80fb7

plugin_JB - removed FILT and BUFFER_REPEAT, added Registers
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 22 Sep 2010 13:31:31 -0400
parents 478bb1f8215c
children
line wrap: on
line diff
--- a/doc/v2_planning/arch_src/plugin_JB_main.py	Wed Sep 22 13:00:39 2010 -0400
+++ b/doc/v2_planning/arch_src/plugin_JB_main.py	Wed Sep 22 13:31:31 2010 -0400
@@ -10,10 +10,10 @@
     def __init__(self, data):
         self.pos = 0
         self.data = data
-    def next(self):
-        rval = self.data[self.pos]
-        self.pos += 1
-        if self.pos == len(self.data):
+    def next(self, n=1):
+        rval = self.data[self.pos:self.pos+n]
+        self.pos += n
+        if self.pos >= len(self.data):
             self.pos = 0
         return rval
     def seek(self, pos):
@@ -28,27 +28,29 @@
     def next_fold(self):
         self.k += 1
         self.data.seek(0) # restart the stream
-    def next(self):
+    def next(self, n=1):
         #TODO: skip the examples that are ommitted in this split
-        return self.data.next()
+        return self.data.next(n)
     def init_test(self):
         pass
-    def next_test(self):
-        return self.data.next()
+    def next_test(self, n=1):
+        return self.data.next(n)
     def test_size(self):
         return 5
     def store_scores(self, scores):
         self.scores[self.k] = scores
 
-    def prog(self, clear, train, test):
-        return REPEAT(self.K, [
+    def prog(self, clear, train, test, test_data_reg, test_counter_reg, test_scores_reg):
+        return REPEAT(self.K, SEQ([
             CALL(self.next_fold),
             clear,
             train,
             CALL(self.init_test),
-            BUFFER_REPEAT(self.test_size(),
-                SEQ([ CALL(self.next_test), test])),
-            FILT(self.store_scores) ])
+            REPEAT(self.test_size(), SEQ([
+                CALL(self.next_test, _set=test_data_reg), 
+                test]),
+                counter=test_counter_reg),
+            CALL(self.store_scores, test_scores_reg)]))
 
 class PCA_Analysis(object):
     def __init__(self):
@@ -93,8 +95,8 @@
         return l[0]
 
     print WEAVE(1, [
-        BUFFER_REPEAT(3,CALL(f,1)),
-        BUFFER_REPEAT(5,CALL(f,1)),
+        REPEAT(3,CALL(f,1)),
+        REPEAT(5,CALL(f,1)),
         ]).run()
 
 def main_weave_popen():
@@ -104,9 +106,9 @@
     p = WEAVE(2,[
         SEQ([POPEN(['sleep', '5']), PRINT('done 1')]),
         SEQ([POPEN(['sleep', '10']), PRINT('done 2')]),
-        LOOP([ 
+        LOOP(SEQ([ 
             CALL(print_obj, 'polling...'),
-            CALL(time.sleep, 1)])])
+            CALL(time.sleep, 1)]))])
     # The LOOP would forever if the WEAVE were not configured to stop after 2 of its elements
     # complete.
 
@@ -120,15 +122,15 @@
     data1 = {0:"blah data1"}
     data2 = {1:"foo data2"}
     p = WEAVE(2,[
-        SPAWN(data1, REPEAT(3, [
+        SPAWN(data1, REPEAT(3, SEQ([
             CALL(importable_fn, data1), 
-            PRINT("hello from 1")])),
-        SPAWN(data2, REPEAT(1, [
+            PRINT("hello from 1")]))),
+        SPAWN(data2, REPEAT(1, SEQ([
             CALL(importable_fn, data2), 
-            PRINT("hello from 2")])),
-        LOOP([ 
+            PRINT("hello from 2")]))),
+        LOOP(SEQ([ 
             CALL(print_obj, 'polling...'),
-            CALL(time.sleep, 0.5)])])
+            CALL(time.sleep, 0.5)]))])
     print 'BEFORE'
     print data1
     print data2
@@ -148,6 +150,7 @@
     layer1 = Layer(w=4)
     layer2 = Layer(w=3)
     kf = KFold(dataset, K=10)
+    reg = Registers()
 
     pca_batchsize=1000
     cd_batchsize = 5
@@ -157,19 +160,19 @@
     # create algorithm
 
     train_pca = SEQ([
-        BUFFER_REPEAT(pca_batchsize, CALL(kf.next)), 
-        FILT(pca.analyze)])
+        CALL(kf.next, pca_batchsize, _set=reg('x')), 
+        CALL(pca.analyze, reg('x'))])
 
-    train_layer1 = REPEAT(n_cd_updates_layer1, [
-        BUFFER_REPEAT(cd_batchsize, CALL(kf.next)),
-        FILT(pca.filt), 
-        FILT(cd1_update, layer=layer1, lr=.01)])
+    train_layer1 = REPEAT(n_cd_updates_layer1, SEQ([
+        CALL(kf.next, cd_batchsize, _set=reg('x')),
+        CALL(pca.filt, reg('x'), _set=reg('x')), 
+        CALL(cd1_update, reg('x'), layer=layer1, lr=.01)]))
 
-    train_layer2 = REPEAT(n_cd_updates_layer2, [
-        BUFFER_REPEAT(cd_batchsize, CALL(kf.next)),
-        FILT(pca.filt), 
-        FILT(layer1.filt),
-        FILT(cd1_update, layer=layer2, lr=.01)])
+    train_layer2 = REPEAT(n_cd_updates_layer2, SEQ([
+        CALL(kf.next, cd_batchsize, _set=reg('x')),
+        CALL(pca.filt, reg('x'), _set=reg('x')), 
+        CALL(layer1.filt, reg('x'), _set=reg('x')),
+        CALL(cd1_update, reg('x'), layer=layer2, lr=.01)]))
 
     kfold_prog = kf.prog(
             clear = SEQ([   # FRAGMENT 1: this bit is the reset/clear stage
@@ -181,14 +184,17 @@
                 train_pca,
                 WEAVE(1, [    # Silly example of how to do debugging / loggin with WEAVE
                     train_layer1, 
-                    LOOP(CALL(print_obj_attr, layer1, 'w'))]),
+                    LOOP(PRINT(reg('x')))]),
                 train_layer2,
                 ]),
             test=SEQ([
-                FILT(pca.filt),       # may want to allow this SEQ to be 
-                FILT(layer1.filt),    # optimized into a shorter one that
-                FILT(layer2.filt),    # compiles these calls together with 
-                FILT(numpy.mean)]))   # Theano
+                CALL(pca.filt, reg('testx'), _set=reg('x')),  
+                CALL(layer1.filt, reg('x'), _set=reg('x')),
+                CALL(layer2.filt, reg('x'), _set=reg('x')),
+                CALL(numpy.mean, reg('x'), _set=reg('score'))]),
+            test_data_reg=reg('testx'),
+            test_counter_reg=reg('i'),
+            test_scores_reg=reg('score'))
 
     pkg1 = dict(prog=kfold_prog, kf=kf)
     pkg2 = copy.deepcopy(pkg1)       # programs can be copied
@@ -206,4 +212,14 @@
 
 
 if __name__ == '__main__':
+    try:
+        sys.argv[1]
+    except:
+        print """You have to tell which main function to use, try:
+    - python plugin_JB_main.py 'main_kfold_dbn()'
+    - python plugin_JB_main.py 'main_weave()'
+    - python plugin_JB_main.py 'main_weave_popen()'
+    - python plugin_JB_main.py 'main_spawn()'
+        """
+        sys.exit(1)
     sys.exit(eval(sys.argv[1]))