Mercurial > ift6266
comparison deep/convolutional_dae/run_exp.py @ 293:d89820070ea0
Add some prints to see the current step.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Fri, 26 Mar 2010 19:18:03 -0400 |
parents | 8108d271c30c |
children | 8babd43235dd |
comparison
equal
deleted
inserted
replaced
292:8108d271c30c | 293:d89820070ea0 |
---|---|
7 def go(state, channel): | 7 def go(state, channel): |
8 from ift6266 import datasets | 8 from ift6266 import datasets |
9 from ift6266.deep.convolutional_dae.sgd_opt import sgd_opt | 9 from ift6266.deep.convolutional_dae.sgd_opt import sgd_opt |
10 import pylearn, theano, ift6266 | 10 import pylearn, theano, ift6266 |
11 import pylearn.version | 11 import pylearn.version |
12 import sys | |
12 | 13 |
13 # params: bsize, pretrain_lr, train_lr, nfilts1, nfilts2, nftils3, nfilts4 | 14 # params: bsize, pretrain_lr, train_lr, nfilts1, nfilts2, nftils3, nfilts4 |
14 # pretrain_rounds, noise, mlp_sz | 15 # pretrain_rounds, noise, mlp_sz |
15 | 16 |
16 pylearn.version.record_versions(state, [theano, ift6266, pylearn]) | 17 pylearn.version.record_versions(state, [theano, ift6266, pylearn]) |
17 # TODO: maybe record pynnet version? | 18 # TODO: maybe record pynnet version? |
18 channel.save() | 19 channel.save() |
19 | 20 |
20 dset = datasets.nist_all(1000) | 21 dset = datasets.nist_digits() |
21 | 22 |
22 nfilts = [] | 23 nfilts = [] |
23 if state.nfilts1 != 0: | 24 if state.nfilts1 != 0: |
24 nfilts.append(state.nfilts1) | 25 nfilts.append(state.nfilts1) |
25 if state.nfilts2 != 0: | 26 if state.nfilts2 != 0: |
51 dset, state.bsize, | 52 dset, state.bsize, |
52 pretrain_funcs, trainf,evalf) | 53 pretrain_funcs, trainf,evalf) |
53 | 54 |
54 series = create_series() | 55 series = create_series() |
55 | 56 |
57 print "pretraining ..." | |
58 sys.stdout.flush() | |
56 do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error']) | 59 do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error']) |
57 | 60 |
61 print "training ..." | |
62 sys.stdout.flush() | |
58 sgd_opt(train, valid, test, training_epochs=100000, patience=10000, | 63 sgd_opt(train, valid, test, training_epochs=100000, patience=10000, |
59 patience_increase=2., improvement_threshold=0.995, | 64 patience_increase=2., improvement_threshold=0.995, |
60 validation_frequency=2500, series=series, net=net) | 65 validation_frequency=1000, series=series, net=net) |
61 | 66 |
62 if __name__ == '__main__': | 67 if __name__ == '__main__': |
63 st = dumb() | 68 st = dumb() |
64 st.bsize = 100 | 69 st.bsize = 100 |
65 st.pretrain_lr = 0.01 | 70 st.pretrain_lr = 0.01 |