comparison deep/convolutional_dae/run_exp.py @ 380:0473b799d449

merge
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 26 Apr 2010 14:56:34 -0400
parents 01445a75c702
children
comparison
equal deleted inserted replaced
379:a21a174c1c18 380:0473b799d449
1 from ift6266.deep.convolutional_dae.scdae import * 1 from ift6266.deep.convolutional_dae.scdae import *
2 2
3 class dumb(object): 3 class dumb(object):
4 COMPLETE = None
4 def save(self): 5 def save(self):
5 pass 6 pass
6 7
7 def go(state, channel): 8 def go(state, channel):
8 from ift6266 import datasets 9 from ift6266 import datasets
16 17
17 pylearn.version.record_versions(state, [theano, ift6266, pylearn]) 18 pylearn.version.record_versions(state, [theano, ift6266, pylearn])
18 # TODO: maybe record pynnet version? 19 # TODO: maybe record pynnet version?
19 channel.save() 20 channel.save()
20 21
21 dset = datasets.nist_all() 22 dset = datasets.nist_P07()
22 23
23 nfilts = [] 24 nfilts = []
25 fsizes = []
24 if state.nfilts1 != 0: 26 if state.nfilts1 != 0:
25 nfilts.append(state.nfilts1) 27 nfilts.append(state.nfilts1)
28 fsizes.append((5,5))
26 if state.nfilts2 != 0: 29 if state.nfilts2 != 0:
27 nfilts.append(state.nfilts2) 30 nfilts.append(state.nfilts2)
31 fsizes.append((3,3))
28 if state.nfilts3 != 0: 32 if state.nfilts3 != 0:
29 nfilts.append(state.nfilts3) 33 nfilts.append(state.nfilts3)
34 fsizes.append((3,3))
30 if state.nfilts4 != 0: 35 if state.nfilts4 != 0:
31 nfilts.append(state.nfilts4) 36 nfilts.append(state.nfilts4)
37 fsizes.append((2,2))
32 38
33 fsizes = [(5,5)]*len(nfilts)
34 subs = [(2,2)]*len(nfilts) 39 subs = [(2,2)]*len(nfilts)
35 noise = [state.noise]*len(nfilts) 40 noise = [state.noise]*len(nfilts)
36 41
37 pretrain_funcs, trainf, evalf, net = build_funcs( 42 pretrain_funcs, trainf, evalf, net = build_funcs(
38 img_size=(32, 32), 43 img_size=(32, 32),
59 do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error']) 64 do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error'])
60 65
61 print "training ..." 66 print "training ..."
62 sys.stdout.flush() 67 sys.stdout.flush()
63 best_valid, test_score = sgd_opt(train, valid, test, 68 best_valid, test_score = sgd_opt(train, valid, test,
64 training_epochs=1000000, patience=2500, 69 training_epochs=800000, patience=2000,
65 patience_increase=2., 70 patience_increase=2.,
66 improvement_threshold=0.995, 71 improvement_threshold=0.995,
67 validation_frequency=500, 72 validation_frequency=500,
68 series=series, net=net) 73 series=series, net=net)
69 state.best_valid = best_valid 74 state.best_valid = best_valid