comparison deep/convolutional_dae/run_exp.py @ 300:6eab220a7d70

Adjust sgd_opt parameters and use nist_all.
author Arnaud Bergeron <abergeron@gmail.com>
date Mon, 29 Mar 2010 17:54:01 -0400
parents a222af1d0598
children 69109e41983e
comparison
equal deleted inserted replaced
299:a9af079892ce 300:6eab220a7d70
16 16
17 pylearn.version.record_versions(state, [theano, ift6266, pylearn]) 17 pylearn.version.record_versions(state, [theano, ift6266, pylearn])
18 # TODO: maybe record pynnet version? 18 # TODO: maybe record pynnet version?
19 channel.save() 19 channel.save()
20 20
21 dset = datasets.nist_digits() 21 dset = datasets.nist_all()
22 22
23 nfilts = [] 23 nfilts = []
24 if state.nfilts1 != 0: 24 if state.nfilts1 != 0:
25 nfilts.append(state.nfilts1) 25 nfilts.append(state.nfilts1)
26 if state.nfilts2 != 0: 26 if state.nfilts2 != 0:
59 do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error']) 59 do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error'])
60 60
61 print "training ..." 61 print "training ..."
62 sys.stdout.flush() 62 sys.stdout.flush()
63 best_valid, test_score = sgd_opt(train, valid, test, 63 best_valid, test_score = sgd_opt(train, valid, test,
64 training_epochs=100000, patience=10000, 64 training_epochs=1000000, patience=2500,
65 patience_increase=2., 65 patience_increase=2.,
66 improvement_threshold=0.995, 66 improvement_threshold=0.995,
67 validation_frequency=1000, 67 validation_frequency=500,
68 series=series, net=net) 68 series=series, net=net)
69 state.best_valid = best_valid 69 state.best_valid = best_valid
70 state.test_score = test_score 70 state.test_score = test_score
71 channel.save() 71 channel.save()
72 return channel.COMPLETE 72 return channel.COMPLETE