comparison doc/v2_planning/plugin_JB.py @ 1199:98954d8cb92d

v2planning - modifs to plugin_JB
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 20 Sep 2010 02:56:11 -0400
parents 1387771296a8
children acfd5e747a75
comparison
equal deleted inserted replaced
1198:1387771296a8 1199:98954d8cb92d
92 92
93 def start(self, arg): 93 def start(self, arg):
94 pass 94 pass
95 def step(self): 95 def step(self):
96 pass 96 pass
97
98 97
99 class BUFFER_REPEAT(ELEMENT): 98 class BUFFER_REPEAT(ELEMENT):
100 """ 99 """
101 Accumulate a number of return values into one list / array. 100 Accumulate a number of return values into one list / array.
102 101
158 raise TypeError('cant get positional args both ways') 157 raise TypeError('cant get positional args both ways')
159 return self.fn(self.start_arg, **self.kwargs) 158 return self.fn(self.start_arg, **self.kwargs)
160 else: 159 else:
161 return self.fn(*self.args, **self.kwargs) 160 return self.fn(*self.args, **self.kwargs)
162 def __getstate__(self): 161 def __getstate__(self):
163 rval = self.__dict__ 162 rval = dict(self.__dict__)
164 if type(self.fn) is type(self.step): #instancemethod 163 if type(self.fn) is type(self.step): #instancemethod
165 fn = rval.pop('fn') 164 fn = rval.pop('fn')
166 rval['i fn'] = fn.im_func, fn.im_self, fn.im_class 165 rval['i fn'] = fn.im_func, fn.im_self, fn.im_class
167 return rval 166 return rval
168 def __setstate__(self, dct): 167 def __setstate__(self, dct):
169 if 'i fn' in dct: 168 if 'i fn' in dct:
170 dct['fn'] = type(self.step)(*dct.pop('i fn')) 169 dct['fn'] = type(self.step)(*dct.pop('i fn'))
171 self.__dict__.update(dct) 170 self.__dict__.update(dct)
172 171
173 def FILT(*args, **kwargs): 172 def FILT(fn, **kwargs):
174 return CALL(use_start_arg=True, *args, **kwargs) 173 """
174 Return a CALL object that uses the return value from the previous CALL as the first and
175 only positional argument.
176 """
177 return CALL(fn, use_start_arg=True, **kwargs)
175 178
176 def CHOOSE(which, options): 179 def CHOOSE(which, options):
177 """ 180 """
178 Execute one out of a number of optional control flow paths 181 Execute one out of a number of optional control flow paths
179 """ 182 """
282 def test_size(self): 285 def test_size(self):
283 return 5 286 return 5
284 def store_scores(self, scores): 287 def store_scores(self, scores):
285 self.scores[self.k] = scores 288 self.scores[self.k] = scores
286 289
290 def prog(self, clear, train, test):
291 return REPEAT(self.K, [
292 CALL(self.next_fold),
293 clear,
294 train,
295 CALL(self.init_test),
296 BUFFER_REPEAT(self.test_size(),
297 SEQ([ CALL(self.next_test), test])),
298 FILT(self.store_scores) ])
299
287 class PCA_Analysis(object): 300 class PCA_Analysis(object):
288 def __init__(self): 301 def __init__(self):
289 self.clear() 302 self.clear()
290 303
291 def clear(self): 304 def clear(self):
314 def print_obj_attr(obj, attr): 327 def print_obj_attr(obj, attr):
315 print getattr(obj, attr) 328 print getattr(obj, attr)
316 def no_op(*args, **kwargs): 329 def no_op(*args, **kwargs):
317 pass 330 pass
318 331
319 class cd1_update(object): 332 def cd1_update(X, layer, lr):
320 def __init__(self, layer, lr): 333 # update self.layer from observation X
321 self.layer = layer 334 layer.w += X.mean() * lr #TODO: not exactly correct math!
322 self.lr = lr
323
324 def __call__(self, X):
325 # update self.layer from observation X
326 self.layer.w += X.mean() * self.lr #TODO: not exactly correct math
327 335
328 def simple_main(): 336 def simple_main():
329 337
330 l = [0] 338 l = [0]
331 def f(a): 339 def f(a):
344 pca = PCA_Analysis() 352 pca = PCA_Analysis()
345 layer1 = Layer(w=4) 353 layer1 = Layer(w=4)
346 layer2 = Layer(w=3) 354 layer2 = Layer(w=3)
347 kf = KFold(dataset, K=10) 355 kf = KFold(dataset, K=10)
348 356
357 pca_batchsize=1000
358 cd_batchsize = 5
359 n_cd_updates_layer1 = 10
360 n_cd_updates_layer2 = 10
361
349 # create algorithm 362 # create algorithm
350 363
351 train_pca = SEQ([ 364 train_pca = SEQ([
352 BUFFER_REPEAT(1000, CALL(kf.next)), 365 BUFFER_REPEAT(pca_batchsize, CALL(kf.next)),
353 FILT(pca.analyze)]) 366 FILT(pca.analyze)])
354 367
355 train_layer1 = REPEAT(10, [ 368 train_layer1 = REPEAT(n_cd_updates_layer1, [
356 BUFFER_REPEAT(10, CALL(kf.next)), 369 BUFFER_REPEAT(cd_batchsize, CALL(kf.next)),
357 FILT(pca.filt), 370 FILT(pca.filt),
358 FILT(cd1_update(layer1, lr=.01))]) 371 FILT(cd1_update, layer=layer1, lr=.01)])
359 372
360 train_layer2 = REPEAT(10, [ 373 train_layer2 = REPEAT(n_cd_updates_layer2, [
361 BUFFER_REPEAT(10, CALL(kf.next)), 374 BUFFER_REPEAT(cd_batchsize, CALL(kf.next)),
362 FILT(pca.filt), 375 FILT(pca.filt),
363 FILT(layer1.filt), 376 FILT(layer1.filt),
364 FILT(cd1_update(layer2, lr=.01))]) 377 FILT(cd1_update, layer=layer2, lr=.01)])
365 378
366 train_prog = SEQ([ 379 kfold_prog = kf.prog(
367 train_pca, 380 clear = SEQ([ # FRAGMENT 1: this bit is the reset/clear stage
368 WEAVE([ 381 CALL(pca.clear),
369 train_layer1, 382 CALL(layer1.clear),
370 LOOP(CALL(print_obj_attr, layer1, 'w'))]), 383 CALL(layer2.clear),
371 train_layer2, 384 ]),
372 ]) 385 train = SEQ([
373 386 train_pca,
374 kfold_prog = REPEAT(10, [ 387 WEAVE([ # Silly example of how to do debugging / loggin with WEAVE
375 CALL(kf.next_fold), 388 train_layer1,
376 CALL(pca.clear), 389 LOOP(CALL(print_obj_attr, layer1, 'w'))]),
377 CALL(layer1.clear), 390 train_layer2,
378 CALL(layer2.clear), 391 ]),
379 train_prog, 392 test=SEQ([
380 CALL(kf.init_test),
381 BUFFER_REPEAT(kf.test_size(),
382 SEQ([
383 CALL(kf.next_test),
384 FILT(pca.filt), # may want to allow this SEQ to be 393 FILT(pca.filt), # may want to allow this SEQ to be
385 FILT(layer1.filt), # optimized into a shorter one that 394 FILT(layer1.filt), # optimized into a shorter one that
386 FILT(layer2.filt), 395 FILT(layer2.filt), # compiles these calls together with
387 FILT(numpy.mean)])), # chains together theano graphs 396 FILT(numpy.mean)])) # Theano
388 FILT(kf.store_scores), 397
389 ]) 398 pkg1 = dict(prog=kfold_prog, kf=kf)
390 399 pkg2 = copy.deepcopy(pkg1) # programs can be copied
391 vm = VirtualMachine(kfold_prog) 400
392 401 try:
393 #vm2 = copy.deepcopy(vm) 402 pkg3 = cPickle.loads(cPickle.dumps(pkg1))
394 vm.run(n_steps=200000) 403 except:
395 print kf.scores 404 print >> sys.stderr, "pickling doesnt work, but it can be fixed I think"
405
406 pkg = pkg2
407
408 # running a program updates the variables in its package, but not the other package
409 VirtualMachine(pkg['prog']).run()
410 print pkg['kf'].scores
396 411
397 412
398 if __name__ == '__main__': 413 if __name__ == '__main__':
399 sys.exit(main()) 414 sys.exit(main())
400 415