Mercurial > pylearn
comparison doc/v2_planning/plugin_JB.py @ 1199:98954d8cb92d
v2planning - modifs to plugin_JB
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 20 Sep 2010 02:56:11 -0400 |
parents | 1387771296a8 |
children | acfd5e747a75 |
comparison
equal
deleted
inserted
replaced
1198:1387771296a8 | 1199:98954d8cb92d |
---|---|
92 | 92 |
93 def start(self, arg): | 93 def start(self, arg): |
94 pass | 94 pass |
95 def step(self): | 95 def step(self): |
96 pass | 96 pass |
97 | |
98 | 97 |
99 class BUFFER_REPEAT(ELEMENT): | 98 class BUFFER_REPEAT(ELEMENT): |
100 """ | 99 """ |
101 Accumulate a number of return values into one list / array. | 100 Accumulate a number of return values into one list / array. |
102 | 101 |
158 raise TypeError('cant get positional args both ways') | 157 raise TypeError('cant get positional args both ways') |
159 return self.fn(self.start_arg, **self.kwargs) | 158 return self.fn(self.start_arg, **self.kwargs) |
160 else: | 159 else: |
161 return self.fn(*self.args, **self.kwargs) | 160 return self.fn(*self.args, **self.kwargs) |
162 def __getstate__(self): | 161 def __getstate__(self): |
163 rval = self.__dict__ | 162 rval = dict(self.__dict__) |
164 if type(self.fn) is type(self.step): #instancemethod | 163 if type(self.fn) is type(self.step): #instancemethod |
165 fn = rval.pop('fn') | 164 fn = rval.pop('fn') |
166 rval['i fn'] = fn.im_func, fn.im_self, fn.im_class | 165 rval['i fn'] = fn.im_func, fn.im_self, fn.im_class |
167 return rval | 166 return rval |
168 def __setstate__(self, dct): | 167 def __setstate__(self, dct): |
169 if 'i fn' in dct: | 168 if 'i fn' in dct: |
170 dct['fn'] = type(self.step)(*dct.pop('i fn')) | 169 dct['fn'] = type(self.step)(*dct.pop('i fn')) |
171 self.__dict__.update(dct) | 170 self.__dict__.update(dct) |
172 | 171 |
173 def FILT(*args, **kwargs): | 172 def FILT(fn, **kwargs): |
174 return CALL(use_start_arg=True, *args, **kwargs) | 173 """ |
174 Return a CALL object that uses the return value from the previous CALL as the first and | |
175 only positional argument. | |
176 """ | |
177 return CALL(fn, use_start_arg=True, **kwargs) | |
175 | 178 |
176 def CHOOSE(which, options): | 179 def CHOOSE(which, options): |
177 """ | 180 """ |
178 Execute one out of a number of optional control flow paths | 181 Execute one out of a number of optional control flow paths |
179 """ | 182 """ |
282 def test_size(self): | 285 def test_size(self): |
283 return 5 | 286 return 5 |
284 def store_scores(self, scores): | 287 def store_scores(self, scores): |
285 self.scores[self.k] = scores | 288 self.scores[self.k] = scores |
286 | 289 |
290 def prog(self, clear, train, test): | |
291 return REPEAT(self.K, [ | |
292 CALL(self.next_fold), | |
293 clear, | |
294 train, | |
295 CALL(self.init_test), | |
296 BUFFER_REPEAT(self.test_size(), | |
297 SEQ([ CALL(self.next_test), test])), | |
298 FILT(self.store_scores) ]) | |
299 | |
287 class PCA_Analysis(object): | 300 class PCA_Analysis(object): |
288 def __init__(self): | 301 def __init__(self): |
289 self.clear() | 302 self.clear() |
290 | 303 |
291 def clear(self): | 304 def clear(self): |
314 def print_obj_attr(obj, attr): | 327 def print_obj_attr(obj, attr): |
315 print getattr(obj, attr) | 328 print getattr(obj, attr) |
316 def no_op(*args, **kwargs): | 329 def no_op(*args, **kwargs): |
317 pass | 330 pass |
318 | 331 |
319 class cd1_update(object): | 332 def cd1_update(X, layer, lr): |
320 def __init__(self, layer, lr): | 333 # update layer from observation X |
321 self.layer = layer | 334 layer.w += X.mean() * lr #TODO: not exactly correct math! |
322 self.lr = lr | |
323 | |
324 def __call__(self, X): | |
325 # update self.layer from observation X | |
326 self.layer.w += X.mean() * self.lr #TODO: not exactly correct math | |
327 | 335 |
328 def simple_main(): | 336 def simple_main(): |
329 | 337 |
330 l = [0] | 338 l = [0] |
331 def f(a): | 339 def f(a): |
344 pca = PCA_Analysis() | 352 pca = PCA_Analysis() |
345 layer1 = Layer(w=4) | 353 layer1 = Layer(w=4) |
346 layer2 = Layer(w=3) | 354 layer2 = Layer(w=3) |
347 kf = KFold(dataset, K=10) | 355 kf = KFold(dataset, K=10) |
348 | 356 |
357 pca_batchsize=1000 | |
358 cd_batchsize = 5 | |
359 n_cd_updates_layer1 = 10 | |
360 n_cd_updates_layer2 = 10 | |
361 | |
349 # create algorithm | 362 # create algorithm |
350 | 363 |
351 train_pca = SEQ([ | 364 train_pca = SEQ([ |
352 BUFFER_REPEAT(1000, CALL(kf.next)), | 365 BUFFER_REPEAT(pca_batchsize, CALL(kf.next)), |
353 FILT(pca.analyze)]) | 366 FILT(pca.analyze)]) |
354 | 367 |
355 train_layer1 = REPEAT(10, [ | 368 train_layer1 = REPEAT(n_cd_updates_layer1, [ |
356 BUFFER_REPEAT(10, CALL(kf.next)), | 369 BUFFER_REPEAT(cd_batchsize, CALL(kf.next)), |
357 FILT(pca.filt), | 370 FILT(pca.filt), |
358 FILT(cd1_update(layer1, lr=.01))]) | 371 FILT(cd1_update, layer=layer1, lr=.01)]) |
359 | 372 |
360 train_layer2 = REPEAT(10, [ | 373 train_layer2 = REPEAT(n_cd_updates_layer2, [ |
361 BUFFER_REPEAT(10, CALL(kf.next)), | 374 BUFFER_REPEAT(cd_batchsize, CALL(kf.next)), |
362 FILT(pca.filt), | 375 FILT(pca.filt), |
363 FILT(layer1.filt), | 376 FILT(layer1.filt), |
364 FILT(cd1_update(layer2, lr=.01))]) | 377 FILT(cd1_update, layer=layer2, lr=.01)]) |
365 | 378 |
366 train_prog = SEQ([ | 379 kfold_prog = kf.prog( |
367 train_pca, | 380 clear = SEQ([ # FRAGMENT 1: this bit is the reset/clear stage |
368 WEAVE([ | 381 CALL(pca.clear), |
369 train_layer1, | 382 CALL(layer1.clear), |
370 LOOP(CALL(print_obj_attr, layer1, 'w'))]), | 383 CALL(layer2.clear), |
371 train_layer2, | 384 ]), |
372 ]) | 385 train = SEQ([ |
373 | 386 train_pca, |
374 kfold_prog = REPEAT(10, [ | 387 WEAVE([ # Silly example of how to do debugging / logging with WEAVE |
375 CALL(kf.next_fold), | 388 train_layer1, |
376 CALL(pca.clear), | 389 LOOP(CALL(print_obj_attr, layer1, 'w'))]), |
377 CALL(layer1.clear), | 390 train_layer2, |
378 CALL(layer2.clear), | 391 ]), |
379 train_prog, | 392 test=SEQ([ |
380 CALL(kf.init_test), | |
381 BUFFER_REPEAT(kf.test_size(), | |
382 SEQ([ | |
383 CALL(kf.next_test), | |
384 FILT(pca.filt), # may want to allow this SEQ to be | 393 FILT(pca.filt), # may want to allow this SEQ to be |
385 FILT(layer1.filt), # optimized into a shorter one that | 394 FILT(layer1.filt), # optimized into a shorter one that |
386 FILT(layer2.filt), | 395 FILT(layer2.filt), # compiles these calls together with |
387 FILT(numpy.mean)])), # chains together theano graphs | 396 FILT(numpy.mean)])) # Theano |
388 FILT(kf.store_scores), | 397 |
389 ]) | 398 pkg1 = dict(prog=kfold_prog, kf=kf) |
390 | 399 pkg2 = copy.deepcopy(pkg1) # programs can be copied |
391 vm = VirtualMachine(kfold_prog) | 400 |
392 | 401 try: |
393 #vm2 = copy.deepcopy(vm) | 402 pkg3 = cPickle.loads(cPickle.dumps(pkg1)) |
394 vm.run(n_steps=200000) | 403 except: |
395 print kf.scores | 404 print >> sys.stderr, "pickling doesnt work, but it can be fixed I think" |
405 | |
406 pkg = pkg2 | |
407 | |
408 # running a program updates the variables in its package, but not the other package | |
409 VirtualMachine(pkg['prog']).run() | |
410 print pkg['kf'].scores | |
396 | 411 |
397 | 412 |
398 if __name__ == '__main__': | 413 if __name__ == '__main__': |
399 sys.exit(main()) | 414 sys.exit(main()) |
400 | 415 |