Mercurial > pylearn
comparison doc/v2_planning/arch_src/plugin_JB.py @ 1219:9fac28d80fb7
plugin_JB - removed FILT and BUFFER_REPEAT, added Registers
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 22 Sep 2010 13:31:31 -0400 |
parents | 478bb1f8215c |
children |
comparison
equal
deleted
inserted
replaced
1218:5d1b5906151c | 1219:9fac28d80fb7 |
---|---|
8 | 8 |
9 # allocate the relevant modules | 9 # allocate the relevant modules |
10 dataset = Dataset(numpy.random.RandomState(123).randn(13,1)) | 10 dataset = Dataset(numpy.random.RandomState(123).randn(13,1)) |
11 pca = PCA_Analysis() | 11 pca = PCA_Analysis() |
12 pca_batchsize=1000 | 12 pca_batchsize=1000 |
13 | |
14 reg = Registers() | |
13 | 15 |
14 # define the control-flow of the algorithm | 16 # define the control-flow of the algorithm |
15 train_pca = SEQ([ | 17 train_pca = SEQ([ |
16 BUFFER_REPEAT(pca_batchsize, CALL(dataset.next)), | 18 REPEAT(pca_batchsize, CALL(dataset.next, store_to=reg('x'))), |
17 FILT(pca.analyze)]) | 19 CALL(pca.analyze, reg('x'))]) |
18 | 20 |
19 # run the program | 21 # run the program |
20 train_pca.run() | 22 train_pca.run() |
21 | 23 |
22 The CALL, SEQ, FILT, and BUFFER_REPEAT are control-flow elements. The control-flow elements I | 24 The CALL, SEQ, and REPEAT are control-flow elements. The control-flow elements I |
23 defined so far are: | 25 defined so far are: |
24 | 26 |
25 - CALL - a basic statement, just calls a python function | 27 - CALL - a basic statement, just calls a python function |
26 - FILT - like call, but passes the return value of the last CALL or FILT to the python function | |
27 - SEQ - a sequence of elements to run in order | 28 - SEQ - a sequence of elements to run in order |
28 - REPEAT - do something N times (and return None or maybe the last CALL?) | 29 - REPEAT - do something N times (and return None or maybe the last CALL?) |
29 - BUFFER_REPEAT - do something N times and accumulate the return value from each iter | |
30 - LOOP - do something an infinite number of times | 30 - LOOP - do something an infinite number of times |
31 - CHOOSE - like a switch statement (should rename to SWITCH) | 31 - CHOOSE - like a switch statement (should rename to SWITCH) |
32 - WEAVE - interleave execution of multiple control-flow elements | 32 - WEAVE - interleave execution of multiple control-flow elements |
33 - POPEN - launch a process and return its status when it's complete | 33 - POPEN - launch a process and return its status when it's complete |
34 - PRINT - a shortcut for CALL(print_obj) | 34 - PRINT - a shortcut for CALL(print_obj) |
35 - SPAWN - run a program fragment asynchronously in another process | |
35 | 36 |
36 | 37 |
37 We don't have many requirements per-se for the architecture, but I think this design respects | 38 We don't have many requirements per-se for the architecture, but I think this design respects |
38 and realizes all of them. | 39 and realizes all of them. |
39 The advantages of this approach are: | 40 The advantages of this approach are: |
93 - which returns INCOMPLETE | 94 - which returns INCOMPLETE |
94 - which returns any other object to indicate completion | 95 - which returns any other object to indicate completion |
95 """ | 96 """ |
96 | 97 |
97 # subclasses should override these methods: | 98 # subclasses should override these methods: |
98 def start(self, arg): | 99 def start(self): |
99 pass | 100 pass |
100 def step(self): | 101 def step(self): |
101 pass | 102 pass |
102 | 103 |
103 # subclasses should typically not override these: | 104 # subclasses should typically not override these: |
104 def run(self, arg=None, n_steps=float('inf')): | 105 def run(self, n_steps=float('inf')): |
105 self.start(arg) | 106 self.start() |
106 i = 0 | 107 i = 0 |
107 r = self.step() | 108 r = self.step() |
108 while r is INCOMPLETE: | 109 while r is INCOMPLETE: |
109 i += 1 | 110 i += 1 |
110 #TODO make sure there is not an off-by-one error | 111 #TODO make sure there is not an off-by-one error |
160 """ | 161 """ |
161 def __init__(self, fn, *args, **kwargs): | 162 def __init__(self, fn, *args, **kwargs): |
162 self.fn = fn | 163 self.fn = fn |
163 self.args = args | 164 self.args = args |
164 self.kwargs=kwargs | 165 self.kwargs=kwargs |
165 self.use_start_arg = kwargs.pop('use_start_arg', False) | 166 def start(self): |
166 def start(self, arg): | |
167 self.start_arg = arg | |
168 self.finished = False | 167 self.finished = False |
169 return self | 168 return self |
170 def step(self): | 169 def step(self): |
171 assert not self.finished | 170 assert not self.finished |
172 self.finished = True | 171 self.finished = True |
173 if self.use_start_arg: | 172 fn_rval = self.fn(*self.lookup_args(), **self.lookup_kwargs()) |
174 if self.args: | 173 if '_set' in self.kwargs: |
175 raise TypeError('cant get positional args both ways') | 174 self.kwargs['_set'].set(fn_rval) |
176 return self.fn(self.start_arg, **self.kwargs) | |
177 else: | |
178 return self.fn(*self.args, **self.kwargs) | |
179 def __getstate__(self): | 175 def __getstate__(self): |
180 rval = dict(self.__dict__) | 176 rval = dict(self.__dict__) |
181 if type(self.fn) is type(self.step): #instancemethod | 177 if type(self.fn) is type(self.step): #instancemethod |
182 fn = rval.pop('fn') | 178 fn = rval.pop('fn') |
183 rval['i fn'] = fn.im_func, fn.im_self, fn.im_class | 179 rval['i fn'] = fn.im_func, fn.im_self, fn.im_class |
185 def __setstate__(self, dct): | 181 def __setstate__(self, dct): |
186 if 'i fn' in dct: | 182 if 'i fn' in dct: |
187 dct['fn'] = type(self.step)(*dct.pop('i fn')) | 183 dct['fn'] = type(self.step)(*dct.pop('i fn')) |
188 self.__dict__.update(dct) | 184 self.__dict__.update(dct) |
189 | 185 |
190 def FILT(fn, **kwargs): | 186 def lookup_args(self): |
191 """ | 187 rval = [] |
192 Return a CALL object that uses the return value from the previous CALL as the first and | 188 for a in self.args: |
193 only positional argument. | 189 if isinstance(a, Register): |
194 """ | 190 rval.append(a.get()) |
195 return CALL(fn, use_start_arg=True, **kwargs) | 191 else: |
192 rval.append(a) | |
193 return rval | |
194 def lookup_kwargs(self): | |
195 rval = {} | |
196 for k,v in self.kwargs.iteritems(): | |
197 if k == '_set': | |
198 continue | |
199 if isinstance(v, Register): | |
200 rval[k] = v.get() | |
201 else: | |
202 rval[k] = v | |
203 return rval | |
196 | 204 |
197 def CHOOSE(which, options): | 205 def CHOOSE(which, options): |
198 """ | 206 """ |
199 Execute one out of a number of optional control flow paths | 207 Execute one out of a number of optional control flow paths |
200 """ | 208 """ |
201 raise NotImplementedError() | 209 raise NotImplementedError() |
202 | 210 |
203 def LOOP(elements): | 211 def LOOP(element): |
204 #TODO: implement a true infinite loop | 212 #TODO: implement a true infinite loop |
205 try: | 213 return REPEAT(sys.maxint, element) |
206 iter(elements) | |
207 return REPEAT(sys.maxint, elements) | |
208 except TypeError: | |
209 return REPEAT(sys.maxint, [elements]) | |
210 | 214 |
211 class REPEAT(ELEMENT): | 215 class REPEAT(ELEMENT): |
212 def __init__(self, N, elements, pass_rvals=False): | 216 def __init__(self, N, element, counter=None): |
213 self.N = N | 217 self.N = N |
214 self.elements = elements | 218 if not isinstance(element, ELEMENT): |
215 self.pass_rvals = pass_rvals | 219 raise TypeError(element) |
220 self.element = element | |
221 self.counter = counter | |
216 | 222 |
217 #TODO: check for N being callable | 223 #TODO: check for N being callable |
218 def start(self, arg): | 224 def start(self): |
219 self.n = 0 #loop iteration | 225 self.n = 0 #loop iteration |
220 self.idx = 0 #element idx | |
221 self.finished = False | 226 self.finished = False |
222 self.elements[0].start(arg) | 227 self.element.start() |
228 if self.counter: | |
229 self.counter.set(0) | |
230 | |
223 def step(self): | 231 def step(self): |
224 assert not self.finished | 232 assert not self.finished |
225 r = self.elements[self.idx].step() | 233 r = self.element.step() |
226 if r is INCOMPLETE: | 234 if r is INCOMPLETE: |
227 return INCOMPLETE | 235 return INCOMPLETE |
228 self.idx += 1 | |
229 if self.idx < len(self.elements): | |
230 self.elements[self.idx].start(r) | |
231 return INCOMPLETE | |
232 self.n += 1 | 236 self.n += 1 |
237 if self.counter: | |
238 self.counter.set(self.n) | |
233 if self.n < self.N: | 239 if self.n < self.N: |
234 self.idx = 0 | 240 self.element.start() |
235 self.elements[self.idx].start(r) | |
236 return INCOMPLETE | 241 return INCOMPLETE |
237 else: | 242 else: |
238 self.finished = True | 243 self.finished = True |
239 return r | 244 return r |
240 | 245 |
241 def SEQ(elements): | 246 class SEQ(ELEMENT): |
242 return REPEAT(1, elements) | 247 def __init__(self, elements): |
248 self.elements = list(elements) | |
249 def start(self): | |
250 if len(self.elements): | |
251 self.elements[0].start() | |
252 self.pos = 0 | |
253 self.finished = False | |
254 def step(self): | |
255 if self.pos == len(self.elements): | |
256 self.finished=True | |
257 return | |
258 r = self.elements[self.pos].step() | |
259 if r is INCOMPLETE: | |
260 return r | |
261 self.pos += 1 | |
262 if self.pos < len(self.elements): | |
263 self.elements[self.pos].start() | |
264 return INCOMPLETE | |
243 | 265 |
244 class WEAVE(ELEMENT): | 266 class WEAVE(ELEMENT): |
245 """ | 267 """ |
246 Interleave execution of a number of elements. | 268 Interleave execution of a number of elements. |
247 | 269 |
251 self.elements = elements | 273 self.elements = elements |
252 if n_required == -1: | 274 if n_required == -1: |
253 self.n_required = len(elements) | 275 self.n_required = len(elements) |
254 else: | 276 else: |
255 self.n_required = n_required | 277 self.n_required = n_required |
256 def start(self, arg): | 278 def start(self): |
257 for el in self.elements: | 279 for el in self.elements: |
258 el.start(arg) | 280 el.start() |
259 self.elem_finished = [0] * len(self.elements) | 281 self.elem_finished = [0] * len(self.elements) |
260 self.idx = 0 | 282 self.idx = 0 |
261 self.finished= False | 283 self.finished= False |
262 def step(self): | 284 def step(self): |
263 assert not self.finished # if this is triggered, we have a broken driver | 285 assert not self.finished # if this is triggered, we have a broken driver |
287 return INCOMPLETE | 309 return INCOMPLETE |
288 | 310 |
289 class POPEN(ELEMENT): | 311 class POPEN(ELEMENT): |
290 def __init__(self, args): | 312 def __init__(self, args): |
291 self.args = args | 313 self.args = args |
292 def start(self, arg): | 314 def start(self): |
293 self.p = subprocess.Popen(self.args) | 315 self.p = subprocess.Popen(self.args) |
294 def step(self): | 316 def step(self): |
295 r = self.p.poll() | 317 r = self.p.poll() |
296 if r is None: | 318 if r is None: |
297 return INCOMPLETE | 319 return INCOMPLETE |
303 class SPAWN(ELEMENT): | 325 class SPAWN(ELEMENT): |
304 SUCCESS = 0 | 326 SUCCESS = 0 |
305 def __init__(self, data, prog): | 327 def __init__(self, data, prog): |
306 self.data = data | 328 self.data = data |
307 self.prog = prog | 329 self.prog = prog |
308 def start(self, arg): | 330 def start(self): |
309 # pickle the (data, prog) pair | 331 # pickle the (data, prog) pair |
310 s = cPickle.dumps((self.data, self.prog)) | 332 s = cPickle.dumps((self.data, self.prog)) |
311 | 333 |
312 # call python with a stub function that | 334 # call python with a stub function that |
313 # unpickles the data, prog pair and starts running the prog | 335 # unpickles the data, prog pair and starts running the prog |
351 rval = prog.run() | 373 rval = prog.run() |
352 os.write(wpipe, cPickle.dumps(data)) | 374 os.write(wpipe, cPickle.dumps(data)) |
353 return SPAWN.SUCCESS | 375 return SPAWN.SUCCESS |
354 #os.close(wpipe) | 376 #os.close(wpipe) |
355 | 377 |
378 class Register(object): | |
379 def __init__(self, registers, key): | |
380 self.registers = registers | |
381 self.key = key | |
382 def set(self, val): | |
383 self.registers[self.key] = val | |
384 def get(self): | |
385 return self.registers[self.key] | |
386 class Registers(dict): | |
387 def __call__(self, key): | |
388 return Register(self, key) | |
356 | 389 |
357 def print_obj(obj): | 390 def print_obj(obj): |
358 print obj | 391 print obj |
359 def print_obj_attr(obj, attr): | 392 def print_obj_attr(obj, attr): |
360 print getattr(obj, attr) | 393 print getattr(obj, attr) |
362 pass | 395 pass |
363 | 396 |
364 def importable_fn(d): | 397 def importable_fn(d): |
365 d['new key'] = len(d) | 398 d['new key'] = len(d) |
366 | 399 |
400 | |
401 if __name__ == '__main__': | |
402 print 'this is the library file, run "python plugin_JB_main.py"' |