comparison deep/convolutional_dae/run_exp.py @ 293:d89820070ea0

Add some prints to see the current step.
author Arnaud Bergeron <abergeron@gmail.com>
date Fri, 26 Mar 2010 19:18:03 -0400
parents 8108d271c30c
children 8babd43235dd
comparison
equal deleted inserted replaced
292:8108d271c30c 293:d89820070ea0
7 def go(state, channel): 7 def go(state, channel):
8 from ift6266 import datasets 8 from ift6266 import datasets
9 from ift6266.deep.convolutional_dae.sgd_opt import sgd_opt 9 from ift6266.deep.convolutional_dae.sgd_opt import sgd_opt
10 import pylearn, theano, ift6266 10 import pylearn, theano, ift6266
11 import pylearn.version 11 import pylearn.version
12 import sys
12 13
13 # params: bsize, pretrain_lr, train_lr, nfilts1, nfilts2, nftils3, nfilts4 14 # params: bsize, pretrain_lr, train_lr, nfilts1, nfilts2, nftils3, nfilts4
14 # pretrain_rounds, noise, mlp_sz 15 # pretrain_rounds, noise, mlp_sz
15 16
16 pylearn.version.record_versions(state, [theano, ift6266, pylearn]) 17 pylearn.version.record_versions(state, [theano, ift6266, pylearn])
17 # TODO: maybe record pynnet version? 18 # TODO: maybe record pynnet version?
18 channel.save() 19 channel.save()
19 20
20 dset = datasets.nist_all(1000) 21 dset = datasets.nist_digits()
21 22
22 nfilts = [] 23 nfilts = []
23 if state.nfilts1 != 0: 24 if state.nfilts1 != 0:
24 nfilts.append(state.nfilts1) 25 nfilts.append(state.nfilts1)
25 if state.nfilts2 != 0: 26 if state.nfilts2 != 0:
51 dset, state.bsize, 52 dset, state.bsize,
52 pretrain_funcs, trainf,evalf) 53 pretrain_funcs, trainf,evalf)
53 54
54 series = create_series() 55 series = create_series()
55 56
57 print "pretraining ..."
58 sys.stdout.flush()
56 do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error']) 59 do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error'])
57 60
61 print "training ..."
62 sys.stdout.flush()
58 sgd_opt(train, valid, test, training_epochs=100000, patience=10000, 63 sgd_opt(train, valid, test, training_epochs=100000, patience=10000,
59 patience_increase=2., improvement_threshold=0.995, 64 patience_increase=2., improvement_threshold=0.995,
60 validation_frequency=2500, series=series, net=net) 65 validation_frequency=1000, series=series, net=net)
61 66
62 if __name__ == '__main__': 67 if __name__ == '__main__':
63 st = dumb() 68 st = dumb()
64 st.bsize = 100 69 st.bsize = 100
65 st.pretrain_lr = 0.01 70 st.pretrain_lr = 0.01