comparison doc/v2_planning/plugin_greenlet.py @ 1197:a60b3472c4ba

more progress on greenlets
author James Bergstra <bergstrj@iro.umontreal.ca>
date Sun, 19 Sep 2010 23:49:24 -0400
parents e9bb3340a870
children acfd5e747a75
1 """plugin_greenlet - draft of library architecture using greenlets""" 1 """plugin_greenlet - draft of library architecture using greenlets"""
2
3
4 """
5
6 - PICKLABLE - algorithms are serializable at all points during execution
7
8 - ITERATOR walks through algorithms with fine granularity
9
10 - COMPONENTS - library provides components on which programs operate
11
12 - ALGORITHMS - library provides algorithms in clean (no hooks) form
13
14 - HOOKS - user can insert print / debug logic with search/replace type calls
15 e.g. prog.find(CALL(cd1_update)).replace_with(SEQ([CALL(cd1_update), CALL(debugfn)]))
16
17 - PRINTING - user can print the 'program code' of an algorithm built from library pieces
18
19 - MODULAR EXPERIMENTS - an experiment object with one (or more?) programs and all of the objects referred to by
20 those programs. It is the preferred type of object to be serialized. The main components of
21 the algorithms should be top-level attributes of the package. This object can be serialized
22 and loaded in another process to implement job migration.
23
24 - OPTIMIZATION - program can be optimized automatically
25 e.g. BUFFER(N, CALL(dataset.next)) can be replaced if dataset.next implements the right
26 attribute/protocol for 'bufferable' or something.
27
28 e.g. SEQ([a,b,c,d]) can be compiled with Theano if sub-sequence is compatible
29
+- don't need greenlets to get efficiency; the implementations of the control-flow ops can
+  manage a stack or stack tree in the vm (as greenlets do, I think), so we don't really
+  need greenlets/stackless
+
+"""
 
 __license__ = None
 __copyright__ = None
 
 import copy, sys
@@ -31,27 +64,36 @@
     assert isinstance(incoming, tuple)
     assert len(incoming)==4
     return incoming
 
 def vm_run(prog, *args, **kwargs):
-
+    #TODO: make this into a class with different ways to start the loop
+    # for example, if the (gr, dest, a, kw) tuple is returned,
+    # then a program could be run for N steps and then paused,
+    # saved, and restarted.
+    n_steps = kwargs.pop('n_steps', float('inf'))
     def vm_loop(gr, dest, a, kw):
-        while True:
+        loop_iter = 0
+        while loop_iter < n_steps:
             if gr == 'return':
-                return a, kw
+                break
             #print 'vm_loop gr=',gr,'args=',a, 'kwargs=', kw
             gr, dest, a, kw = gr.switch(vm, gr, dest, a, kw)
             #print 'gmain incoming', incoming
+            loop_iter += 1
+        # permit restarting
+        return gr, dest, a, kw
     vm = greenlet(vm_loop)
-
     return vm.switch(prog, 'return', args, kwargs)
 
-
-def seq(glets):
-    return repeat(1, glets)
-
-def repeat(N, glets):
+####################################################
+# CONTROL-FLOW CONSTRUCTS
+
+def SEQ(glets):
+    return REPEAT(1, glets)
+
+def REPEAT(N, glets):
     def repeat_task(vm, gself, dest, args, kwargs):
         while True:
             for i in xrange(N):
                 for glet in glets:
                     #print 'repeat_task_i dest=%(dest)s args=%(args)s, kw=%(kwargs)s'%locals()
@@ -61,161 +103,223 @@
                     assert _gself is gself
                     assert _dest is None # instructions can't tell us where to jump
             vm, gself, dest, args, kwargs = vm.switch(dest, None, args, kwargs)
     return greenlet(repeat_task)
 
-def choose(which, options):
+def LOOP(seq):
+    #TODO: implement a true infinite loop
+    try:
+        iter(seq)
+        return REPEAT(sys.maxint, seq)
+    except TypeError:
+        return REPEAT(sys.maxint, [seq])
+
+def CHOOSE(which, options):
     raise NotImplementedError()
 
-def weave(threads):
-    raise NotImplementedError()
-
-def service(fn):
+def WEAVE(threads):
+    def weave_task(vm, gself, dest, args, kwargs):
+        # weave works by telling its threads that *it* is the vm
+        # and reporting back to the real vm indirectly
+        while True: # execution of weave is an iteration through this loop
+            # initially broadcast the args and kwargs to all threads
+            all_threads_live = True
+            thread_info = [(t, 'return', args, kwargs) for t in threads]
+            #print 'weave start -------------'
+            while all_threads_live:
+                #print 'weave iter'
+                for i in xrange(len(threads)):
+                    t_next, t_dest, t_args, t_kwargs = thread_info[i]
+                    #tell the vm we're up to something, but ask it to come right back
+                    #print 'weave 1'
+                    _ignore = vm.switch(gself, None, (), {})
+
+                    # pretend we're the vm_loop and tell the
+                    # thread to advance by one and report back to us
+                    #print 'weave 2', thread_info[i]
+                    thread_info[i] = t_next.switch(gself, t_next, t_dest, t_args, t_kwargs)
+                    #print 'weave 3', thread_info[i]
+                    if thread_info[i][0] is 'return':
+                        #print 'thread has finished', i
+                        all_threads_live = False
+
+            # some thread has died so we return control to parent
+            #####print 'weave returning', dest, args, kwargs
+            vm, gself, dest, args, kwargs = vm.switch(dest, None, args, kwargs)
+            #####print 'weave continuing', dest, args, kwargs
+    return greenlet(weave_task)
+
+def BUFFER(N, glet):
+    def BUFFER_loop(vm, gself, dest, args, kwargs):
+        while True: #body runs once per execution
+            buf = []
+            for i in xrange(N):
+                # jump to task `glet`
+                # with instructions to report results back to this loop `g`
+                _vm, _gself, _dest, _args, _kwargs = vm.switch(glet, gself, args, kwargs)
+                buf.append(_args[0])
+                assert len(_args)==1
+                assert _kwargs=={}
+                assert _gself is gself
+                assert _dest is None # instructions can't tell us where to jump
+            buf = numpy.asarray(buf)
+            vm, gself, dest, args, kwargs = vm.switch(dest, None, (buf,), {})
+    return greenlet(BUFFER_loop)
+
+def CALL(fn):
     """
     Create a greenlet whose first argument is the return-jump location.
 
     fn must accept as the first positional argument this greenlet itself, which can be used as
     the return-jump location for internal greenlet switches (ideally using gswitch).
     """
-    def service_loop(vm, gself, dest, args, kwargs):
+    def CALL_loop(vm, gself, dest, args, kwargs):
         while True:
-            #print 'service calling', fn.__name__, args, kwargs
-            t = fn(vm, gself, *args, **kwargs)
+            #print 'CALL calling', fn.__name__, args, kwargs
+            t = fn(*args, **kwargs)
            #TODO consider a protocol for returning args, kwargs
             if t is None:
                 _vm,_gself,dest, args, kwargs = vm.switch(dest, None, (), {})
             else:
                 _vm,_gself,dest, args, kwargs = vm.switch(dest, None, (t,), {})

             assert gself is _gself
-    return greenlet(service_loop)
+    return greenlet(CALL_loop)
 
 ####################################################
+# Components involved in the learning process
 
 class Dataset(object):
     def __init__(self, data):
         self.pos = 0
         self.data = data
-    def next(self, vm, gself):
+    def next(self):
         rval = self.data[self.pos]
         self.pos += 1
         if self.pos == len(self.data):
             self.pos = 0
         return rval
+    def seek(self, pos):
+        self.pos = pos
+
+class KFold(object):
+    def __init__(self, data, K):
+        self.data = data
+        self.k = -1
+        self.scores = [None]*K
+        self.K = K
+    def next_fold(self):
+        self.k += 1
+        self.data.seek(0) # restart the stream
+    def next(self):
+        #TODO: skip the examples that are omitted in this split
+        return self.data.next()
+    def init_test(self):
+        pass
+    def next_test(self):
+        return self.data.next()
+    def test_size(self):
+        return 5
+    def store_scores(self, scores):
+        self.scores[self.k] = scores
 
 class PCA_Analysis(object):
     def __init__(self):
+        self.clear()
+
+    def clear(self):
         self.mean = 0
         self.eigvecs=0
         self.eigvals=0
-    def analyze(self, me, X):
+    def analyze(self, X):
         self.mean = X.mean(axis=0)
         self.eigvecs=1
         self.eigvals=1
-    def filt(self,me, X):
-        return (self.X - self.mean) * self.eigvecs #TODO: divide by root eigvals?
+    def filt(self, X):
+        return (X - self.mean) * self.eigvecs #TODO: divide by root eigvals?
     def pseudo_inverse(self, Y):
         return Y
 
 class Layer(object):
     def __init__(self, w):
         self.w = w
     def filt(self, x):
         return self.w*x
-
-def batches(src, N):
-    # src is a service
-    def rval(me):
-        print 'batches src=', src, 'me=', me
-        return numpy.asarray([gswitch(src, me)[0][0] for i in range(N)])
-    return rval
+    def clear(self):
+        self.w =0
 
 def print_obj(vm, gself, obj):
     print obj
 def no_op(*args, **kwargs):
     pass
 
-def build_pca_trainer(data_src, pca_module, N):
-    return greenlet(
-            batches(
-                N=5,
-                src=inf_data,
-                dest=flow(pca_module.analyze,
-                    dest=layer1_trainer)))
+class cd1_update(object):
+    def __init__(self, layer, lr):
+        self.layer = layer
+        self.lr = lr
+
+    def __call__(self, X):
+        # update self.layer from observation X
+        print 'cd1', X
+        print X.mean()
+        self.layer.w += X.mean() * self.lr #TODO: not exactly correct math
 
 def main():
-    dataset = Dataset(numpy.random.RandomState(123).randn(10,2))
-
-    prog=repeat(3, [service(dataset.next),service(print_obj)])
-    vm_run(prog)
-    vm_run(prog)
-
-
-def main_arch():
-
     # create components
-    dataset = Dataset(numpy.random.RandomState(123).randn(10,2))
-    pca_module = PCA_Analysis()
+    dataset = Dataset(numpy.random.RandomState(123).randn(13,1))
+    pca = PCA_Analysis()
     layer1 = Layer(w=4)
     layer2 = Layer(w=3)
     kf = KFold(dataset, K=10)
 
     # create algorithm
 
-    train_pca = seq([ np_batch(kf.next, 1000), pca.analyze])
-    train_layer1 = repeat(100, [kf.next, pca.filt, cd1_update(layer1, lr=.01)])
-
-    algo = repeat(10, [
-        KFold.step,
-        seq([train_pca,
-            train_layer1,
-            train_layer2,
-            train_classifier,
-            save_classifier,
-            test_classifier]),
-        KFold.set_score])
-
-    gswitch(algo)
+    train_pca = SEQ([
+        BUFFER(1000, CALL(kf.next)),
+        CALL(pca.analyze)])
+
+    train_layer1 = REPEAT(10, [
+        BUFFER(10, CALL(kf.next)),
+        CALL(pca.filt),
+        CALL(cd1_update(layer1, lr=.01))])
+
+    train_layer2 = REPEAT(10, [
+        BUFFER(10, CALL(kf.next)),
+        CALL(pca.filt),
+        CALL(layer1.filt),
+        CALL(cd1_update(layer2, lr=.01))])
 
-
-def main1():
-    dataset = Dataset(numpy.random.RandomState(123).randn(10,2))
-    pca_module = PCA_Analysis()
-
-    # pca
-    next_data = service(dataset.next)
-    b5 = service(batches(src=next_data, N=5))
-    print_pca_analyze = flow(pca_module.analyze, dest=sink(print_obj))
-
-    # layer1_training
-    layer1_training = driver(
-            fn=cd1_trainer(layer1),
-            srcs=[],
-            )
-
-    gswitch(b5, print_pca_analyze)
-
+    def print_layer_w(*a,**kw):
+        print layer1.w
+
+    train_prog = SEQ([
+        train_pca,
+        WEAVE([
+            train_layer1,
+            LOOP(CALL(print_layer_w))]),
+        train_layer2,
+        ])
+
+    kfold_prog = REPEAT(10, [
+        CALL(kf.next_fold),
+        CALL(pca.clear),
+        CALL(layer1.clear),
+        CALL(layer2.clear),
+        train_prog,
+        CALL(kf.init_test),
+        BUFFER(kf.test_size(),
+            SEQ([
+                CALL(kf.next_test),
+                CALL(pca.filt),        # may want to allow this SEQ to be
+                CALL(layer1.filt),     # optimized into a shorter one that
+                CALL(layer2.filt),
+                CALL(numpy.mean)])),   # chains together theano graphs
+        CALL(kf.store_scores),
+        ])
+
+    vm_run(kfold_prog, n_steps=500)
+    print kf.scores
+
+
 if __name__ == '__main__':
     sys.exit(main())
 
-
-
-def flow(fn, dest):
-    def rval(*args, **kwargs):
-        while True:
-            print 'flow calling', fn.__name__, args, kwargs
-            t = fn(g, *args, **kwargs)
-            args, kwargs = gswitch(dest, t)
-    g = greenlet(rval)
-    return g
-
-def sink(fn):
-    def rval(*args, **kwargs):
-        return fn(g, *args, **kwargs)
-    g = greenlet(rval)
-    return g
-
-def consumer(fn, src):
-    def rval(*args, **kwargs):
-        while True:
-            fn(gswitch(src, *args, **kwargs))
-    return greenlet(rval)
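
The uppercase constructors this changeset introduces (SEQ, REPEAT, LOOP, CHOOSE, WEAVE, BUFFER, CALL) build a program out of greenlets that the small VM in vm_run steps through one switch at a time. A minimal driver in the spirit of the main() that this revision rewrites is sketched below; the import line and the show() helper are illustrative assumptions rather than part of the changeset, and the sketch assumes the greenlet package is installed (Python 2, matching the file).

# Hypothetical driver script -- not part of the changeset.
# Mirrors the structure of the old main() removed above, using the new
# uppercase constructors.  Assumes plugin_greenlet is importable.
import numpy
from plugin_greenlet import CALL, REPEAT, vm_run, Dataset

def show(x):
    # stand-in debug hook: print whatever the previous instruction produced
    print 'saw', x

dataset = Dataset(numpy.random.RandomState(123).randn(10, 2))

# three passes of (fetch an example, print it)
prog = REPEAT(3, [CALL(dataset.next), CALL(show)])
vm_run(prog)

Passing n_steps pauses the VM loop after a fixed number of switches, which is how the k-fold program above is launched (vm_run(kfold_prog, n_steps=500)).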
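Every instruction in one of these programs is a greenlet that follows the same switching protocol as CALL_loop and BUFFER_loop above: it is entered with (vm, gself, dest, args, kwargs) and hands its results back by switching to the vm with (dest, None, args, kwargs). The sketch below shows a hypothetical pass-through instruction written against that protocol; TAP is not part of the changeset, but it is one way to get the print/debug HOOKS mentioned in the module docstring without a find/replace API.

# Hypothetical debugging instruction -- not in the changeset -- written
# against the (vm, gself, dest, args, kwargs) protocol used by CALL_loop.
from greenlet import greenlet

def TAP(label):
    # pass args/kwargs through unchanged, printing them as they go by
    def TAP_loop(vm, gself, dest, args, kwargs):
        while True:
            print label, args, kwargs
            # report back to the vm, asking it to jump to `dest` next
            _vm, _gself, dest, args, kwargs = vm.switch(dest, None, args, kwargs)
            assert _gself is gself
    return greenlet(TAP_loop)

A TAP could then be spliced into any program, e.g. SEQ([CALL(kf.next), TAP('after next'), CALL(pca.filt)]).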