comparison doc/v2_planning/arch_src/plugin_JB_main.py @ 1219:9fac28d80fb7

plugin_JB - removed FILT and BUFFER_REPEAT, added Registers
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 22 Sep 2010 13:31:31 -0400
parents 478bb1f8215c
children
comparison
equal deleted inserted replaced
1218:5d1b5906151c 1219:9fac28d80fb7
8 8
9 class Dataset(object): 9 class Dataset(object):
10 def __init__(self, data): 10 def __init__(self, data):
11 self.pos = 0 11 self.pos = 0
12 self.data = data 12 self.data = data
13 def next(self): 13 def next(self, n=1):
14 rval = self.data[self.pos] 14 rval = self.data[self.pos:self.pos+n]
15 self.pos += 1 15 self.pos += n
16 if self.pos == len(self.data): 16 if self.pos >= len(self.data):
17 self.pos = 0 17 self.pos = 0
18 return rval 18 return rval
19 def seek(self, pos): 19 def seek(self, pos):
20 self.pos = pos 20 self.pos = pos
21 21
26 self.scores = [None]*K 26 self.scores = [None]*K
27 self.K = K 27 self.K = K
28 def next_fold(self): 28 def next_fold(self):
29 self.k += 1 29 self.k += 1
30 self.data.seek(0) # restart the stream 30 self.data.seek(0) # restart the stream
31 def next(self): 31 def next(self, n=1):
32 #TODO: skip the examples that are ommitted in this split 32 #TODO: skip the examples that are ommitted in this split
33 return self.data.next() 33 return self.data.next(n)
34 def init_test(self): 34 def init_test(self):
35 pass 35 pass
36 def next_test(self): 36 def next_test(self, n=1):
37 return self.data.next() 37 return self.data.next(n)
38 def test_size(self): 38 def test_size(self):
39 return 5 39 return 5
40 def store_scores(self, scores): 40 def store_scores(self, scores):
41 self.scores[self.k] = scores 41 self.scores[self.k] = scores
42 42
43 def prog(self, clear, train, test): 43 def prog(self, clear, train, test, test_data_reg, test_counter_reg, test_scores_reg):
44 return REPEAT(self.K, [ 44 return REPEAT(self.K, SEQ([
45 CALL(self.next_fold), 45 CALL(self.next_fold),
46 clear, 46 clear,
47 train, 47 train,
48 CALL(self.init_test), 48 CALL(self.init_test),
49 BUFFER_REPEAT(self.test_size(), 49 REPEAT(self.test_size(), SEQ([
50 SEQ([ CALL(self.next_test), test])), 50 CALL(self.next_test, _set=test_data_reg),
51 FILT(self.store_scores) ]) 51 test]),
52 counter=test_counter_reg),
53 CALL(self.store_scores, test_scores_reg)]))
52 54
53 class PCA_Analysis(object): 55 class PCA_Analysis(object):
54 def __init__(self): 56 def __init__(self):
55 self.clear() 57 self.clear()
56 58
91 print l 93 print l
92 l[0] += a 94 l[0] += a
93 return l[0] 95 return l[0]
94 96
95 print WEAVE(1, [ 97 print WEAVE(1, [
96 BUFFER_REPEAT(3,CALL(f,1)), 98 REPEAT(3,CALL(f,1)),
97 BUFFER_REPEAT(5,CALL(f,1)), 99 REPEAT(5,CALL(f,1)),
98 ]).run() 100 ]).run()
99 101
100 def main_weave_popen(): 102 def main_weave_popen():
101 # Uses weave and Popen to demonstrate the control of a program with some asynchronous 103 # Uses weave and Popen to demonstrate the control of a program with some asynchronous
102 # parallelism 104 # parallelism
103 105
104 p = WEAVE(2,[ 106 p = WEAVE(2,[
105 SEQ([POPEN(['sleep', '5']), PRINT('done 1')]), 107 SEQ([POPEN(['sleep', '5']), PRINT('done 1')]),
106 SEQ([POPEN(['sleep', '10']), PRINT('done 2')]), 108 SEQ([POPEN(['sleep', '10']), PRINT('done 2')]),
107 LOOP([ 109 LOOP(SEQ([
108 CALL(print_obj, 'polling...'), 110 CALL(print_obj, 'polling...'),
109 CALL(time.sleep, 1)])]) 111 CALL(time.sleep, 1)]))])
110 # The LOOP would forever if the WEAVE were not configured to stop after 2 of its elements 112 # The LOOP would forever if the WEAVE were not configured to stop after 2 of its elements
111 # complete. 113 # complete.
112 114
113 p.run() 115 p.run()
114 # Note that the program can be run multiple times... 116 # Note that the program can be run multiple times...
118 # illustate the use of SPAWN to drive a set of control programs 120 # illustate the use of SPAWN to drive a set of control programs
119 # in other processes 121 # in other processes
120 data1 = {0:"blah data1"} 122 data1 = {0:"blah data1"}
121 data2 = {1:"foo data2"} 123 data2 = {1:"foo data2"}
122 p = WEAVE(2,[ 124 p = WEAVE(2,[
123 SPAWN(data1, REPEAT(3, [ 125 SPAWN(data1, REPEAT(3, SEQ([
124 CALL(importable_fn, data1), 126 CALL(importable_fn, data1),
125 PRINT("hello from 1")])), 127 PRINT("hello from 1")]))),
126 SPAWN(data2, REPEAT(1, [ 128 SPAWN(data2, REPEAT(1, SEQ([
127 CALL(importable_fn, data2), 129 CALL(importable_fn, data2),
128 PRINT("hello from 2")])), 130 PRINT("hello from 2")]))),
129 LOOP([ 131 LOOP(SEQ([
130 CALL(print_obj, 'polling...'), 132 CALL(print_obj, 'polling...'),
131 CALL(time.sleep, 0.5)])]) 133 CALL(time.sleep, 0.5)]))])
132 print 'BEFORE' 134 print 'BEFORE'
133 print data1 135 print data1
134 print data2 136 print data2
135 p.run() 137 p.run()
136 print 'AFTER' 138 print 'AFTER'
146 dataset = Dataset(numpy.random.RandomState(123).randn(13,1)) 148 dataset = Dataset(numpy.random.RandomState(123).randn(13,1))
147 pca = PCA_Analysis() 149 pca = PCA_Analysis()
148 layer1 = Layer(w=4) 150 layer1 = Layer(w=4)
149 layer2 = Layer(w=3) 151 layer2 = Layer(w=3)
150 kf = KFold(dataset, K=10) 152 kf = KFold(dataset, K=10)
153 reg = Registers()
151 154
152 pca_batchsize=1000 155 pca_batchsize=1000
153 cd_batchsize = 5 156 cd_batchsize = 5
154 n_cd_updates_layer1 = 10 157 n_cd_updates_layer1 = 10
155 n_cd_updates_layer2 = 10 158 n_cd_updates_layer2 = 10
156 159
157 # create algorithm 160 # create algorithm
158 161
159 train_pca = SEQ([ 162 train_pca = SEQ([
160 BUFFER_REPEAT(pca_batchsize, CALL(kf.next)), 163 CALL(kf.next, pca_batchsize, _set=reg('x')),
161 FILT(pca.analyze)]) 164 CALL(pca.analyze, reg('x'))])
162 165
163 train_layer1 = REPEAT(n_cd_updates_layer1, [ 166 train_layer1 = REPEAT(n_cd_updates_layer1, SEQ([
164 BUFFER_REPEAT(cd_batchsize, CALL(kf.next)), 167 CALL(kf.next, cd_batchsize, _set=reg('x')),
165 FILT(pca.filt), 168 CALL(pca.filt, reg('x'), _set=reg('x')),
166 FILT(cd1_update, layer=layer1, lr=.01)]) 169 CALL(cd1_update, reg('x'), layer=layer1, lr=.01)]))
167 170
168 train_layer2 = REPEAT(n_cd_updates_layer2, [ 171 train_layer2 = REPEAT(n_cd_updates_layer2, SEQ([
169 BUFFER_REPEAT(cd_batchsize, CALL(kf.next)), 172 CALL(kf.next, cd_batchsize, _set=reg('x')),
170 FILT(pca.filt), 173 CALL(pca.filt, reg('x'), _set=reg('x')),
171 FILT(layer1.filt), 174 CALL(layer1.filt, reg('x'), _set=reg('x')),
172 FILT(cd1_update, layer=layer2, lr=.01)]) 175 CALL(cd1_update, reg('x'), layer=layer2, lr=.01)]))
173 176
174 kfold_prog = kf.prog( 177 kfold_prog = kf.prog(
175 clear = SEQ([ # FRAGMENT 1: this bit is the reset/clear stage 178 clear = SEQ([ # FRAGMENT 1: this bit is the reset/clear stage
176 CALL(pca.clear), 179 CALL(pca.clear),
177 CALL(layer1.clear), 180 CALL(layer1.clear),
179 ]), 182 ]),
180 train = SEQ([ 183 train = SEQ([
181 train_pca, 184 train_pca,
182 WEAVE(1, [ # Silly example of how to do debugging / loggin with WEAVE 185 WEAVE(1, [ # Silly example of how to do debugging / loggin with WEAVE
183 train_layer1, 186 train_layer1,
184 LOOP(CALL(print_obj_attr, layer1, 'w'))]), 187 LOOP(PRINT(reg('x')))]),
185 train_layer2, 188 train_layer2,
186 ]), 189 ]),
187 test=SEQ([ 190 test=SEQ([
188 FILT(pca.filt), # may want to allow this SEQ to be 191 CALL(pca.filt, reg('testx'), _set=reg('x')),
189 FILT(layer1.filt), # optimized into a shorter one that 192 CALL(layer1.filt, reg('x'), _set=reg('x')),
190 FILT(layer2.filt), # compiles these calls together with 193 CALL(layer2.filt, reg('x'), _set=reg('x')),
191 FILT(numpy.mean)])) # Theano 194 CALL(numpy.mean, reg('x'), _set=reg('score'))]),
195 test_data_reg=reg('testx'),
196 test_counter_reg=reg('i'),
197 test_scores_reg=reg('score'))
192 198
193 pkg1 = dict(prog=kfold_prog, kf=kf) 199 pkg1 = dict(prog=kfold_prog, kf=kf)
194 pkg2 = copy.deepcopy(pkg1) # programs can be copied 200 pkg2 = copy.deepcopy(pkg1) # programs can be copied
195 201
196 try: 202 try:
204 pkg['prog'].run() 210 pkg['prog'].run()
205 print pkg['kf'].scores 211 print pkg['kf'].scores
206 212
207 213
208 if __name__ == '__main__': 214 if __name__ == '__main__':
215 try:
216 sys.argv[1]
217 except:
218 print """You have to tell which main function to use, try:
219 - python plugin_JB_main.py 'main_kfold_dbn()'
220 - python plugin_JB_main.py 'main_weave()'
221 - python plugin_JB_main.py 'main_weave_popen()'
222 - python plugin_JB_main.py 'main_spawn()'
223 """
224 sys.exit(1)
209 sys.exit(eval(sys.argv[1])) 225 sys.exit(eval(sys.argv[1]))