Mercurial > ift6266
comparison deep/convolutional_dae/run_exp.py @ 291:7d1fa2d7721c
Split out the run_exp method.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Fri, 26 Mar 2010 18:35:23 -0400 |
parents | |
children | 8108d271c30c |
comparison
equal
deleted
inserted
replaced
290:518589bfee55 | 291:7d1fa2d7721c |
---|---|
1 from ift6266.deep.convolutional_dae.scdae import * | |
2 | |
class dumb(object):
    """Do-nothing placeholder object for running ``go`` outside a job runner.

    Arbitrary attributes may be assigned to instances (the ``__main__``
    driver uses one as the hyper-parameter ``state``), and ``save`` is a
    no-op so one can also be passed as the ``channel``.
    """

    def save(self):
        """Ignore the save request; always returns None."""
        return None
6 | |
7 def go(state, channel): | |
8 from ift6266 import datasets | |
9 from ift6266.deep.convolutional_dae.sgd_opt import sgd_opt | |
10 import pylearn, theano, ift6266 | |
11 import pylearn.version | |
12 | |
13 # params: bsize, pretrain_lr, train_lr, nfilts1, nfilts2, nftils3, nfilts4 | |
14 # pretrain_rounds, noise, mlp_sz | |
15 | |
16 pylearn.version.record_versions(state, [theano, ift6266, pylearn]) | |
17 # TODO: maybe record pynnet version? | |
18 channel.save() | |
19 | |
20 dset = datasets.nist_all(1000) | |
21 | |
22 nfilts = [] | |
23 if state.nfilts1 != 0: | |
24 nfilts.append(state.nfilts1) | |
25 if state.nfilts2 != 0: | |
26 nfilts.append(state.nfilts2) | |
27 if state.nfilts3 != 0: | |
28 nfilts.append(state.nfilts3) | |
29 if state.nfilts4 != 0: | |
30 nfilts.append(state.nfilts4) | |
31 | |
32 fsizes = [(5,5)]*len(nfilts) | |
33 subs = [(2,2)]*len(nfilts) | |
34 noise = [state.noise]*len(nfilts) | |
35 | |
36 pretrain_funcs, trainf, evalf, net = build_funcs( | |
37 img_size=(32, 32), | |
38 batch_size=state.bsize, | |
39 filter_sizes=fsizes, | |
40 num_filters=nfilts, | |
41 subs=subs, | |
42 noise=noise, | |
43 mlp_sizes=[state.mlp_sz], | |
44 out_size=62, | |
45 dtype=numpy.float32, | |
46 pretrain_lr=state.pretrain_lr, | |
47 train_lr=state.train_lr) | |
48 | |
49 pretrain_fs, train, valid, test = massage_funcs( | |
50 state.bsize, dset, pretrain_funcs, trainf, evalf) | |
51 | |
52 series = create_series() | |
53 | |
54 do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error']) | |
55 | |
56 sgd_opt(train, valid, test, training_epochs=100000, patience=10000, | |
57 patience_increase=2., improvement_threshold=0.995, | |
58 validation_frequency=2500, series=series, net=net) | |
59 | |
if __name__ == '__main__':
    # Quick local run: a `dumb` instance stands in for both the
    # hyper-parameter state and the channel.
    st = dumb()
    st.bsize = 100
    st.pretrain_lr = 0.01
    st.train_lr = 0.1
    st.nfilts1 = 4
    st.nfilts2 = 4
    st.nfilts3 = 0
    # go() also reads state.nfilts4, so set it explicitly (0 = no fourth
    # layer) instead of leaving it undefined.
    st.nfilts4 = 0
    st.pretrain_rounds = 500
    st.noise = 0.2
    st.mlp_sz = 500
    go(st, dumb())