Mercurial > ift6266
changeset 291:7d1fa2d7721c
Split out the run_exp method.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Fri, 26 Mar 2010 18:35:23 -0400 |
parents | 518589bfee55 |
children | 8108d271c30c |
files | deep/convolutional_dae/run_exp.py |
diffstat | 1 files changed, 71 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/deep/convolutional_dae/run_exp.py Fri Mar 26 18:35:23 2010 -0400 @@ -0,0 +1,71 @@ +from ift6266.deep.convolutional_dae.scdae import * + +class dumb(object): + def save(self): + pass + +def go(state, channel): + from ift6266 import datasets + from ift6266.deep.convolutional_dae.sgd_opt import sgd_opt + import pylearn, theano, ift6266 + import pylearn.version + + # params: bsize, pretrain_lr, train_lr, nfilts1, nfilts2, nftils3, nfilts4 + # pretrain_rounds, noise, mlp_sz + + pylearn.version.record_versions(state, [theano, ift6266, pylearn]) + # TODO: maybe record pynnet version? + channel.save() + + dset = datasets.nist_all(1000) + + nfilts = [] + if state.nfilts1 != 0: + nfilts.append(state.nfilts1) + if state.nfilts2 != 0: + nfilts.append(state.nfilts2) + if state.nfilts3 != 0: + nfilts.append(state.nfilts3) + if state.nfilts4 != 0: + nfilts.append(state.nfilts4) + + fsizes = [(5,5)]*len(nfilts) + subs = [(2,2)]*len(nfilts) + noise = [state.noise]*len(nfilts) + + pretrain_funcs, trainf, evalf, net = build_funcs( + img_size=(32, 32), + batch_size=state.bsize, + filter_sizes=fsizes, + num_filters=nfilts, + subs=subs, + noise=noise, + mlp_sizes=[state.mlp_sz], + out_size=62, + dtype=numpy.float32, + pretrain_lr=state.pretrain_lr, + train_lr=state.train_lr) + + pretrain_fs, train, valid, test = massage_funcs( + state.bsize, dset, pretrain_funcs, trainf, evalf) + + series = create_series() + + do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error']) + + sgd_opt(train, valid, test, training_epochs=100000, patience=10000, + patience_increase=2., improvement_threshold=0.995, + validation_frequency=2500, series=series, net=net) + +if __name__ == '__main__': + st = dumb() + st.bsize = 100 + st.pretrain_lr = 0.01 + st.train_lr = 0.1 + st.nfilts1 = 4 + st.nfilts2 = 4 + st.nfilts3 = 0 + st.pretrain_rounds = 500 + st.noise=0.2 + st.mlp_sz = 500 + go(st, dumb())