diff deep/convolutional_dae/scdae.py @ 288:80ee63c3e749
Add net saving (only the best model) and error saving using SeriesTable
author      Arnaud Bergeron <abergeron@gmail.com>
date        Fri, 26 Mar 2010 17:24:17 -0400
parents     20ebc1f2a9fe
children    518589bfee55
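The changeset threads error logging through SeriesTables: each pretraining step is recorded under a (layer, epoch) index, and a DummySeries stands in when nothing should be written (as in the __main__ block of the diff). A minimal sketch of the new do_pretrain() call pattern, assuming DummySeries is exported by ift6266.utils.seriestables as the star import in the diff suggests; the lambda pretraining functions are hypothetical stand-ins:

    # Each pretraining function returns its reconstruction cost; the series
    # stores one value per (layer, epoch) pair.
    from ift6266.utils.seriestables import DummySeries

    def do_pretrain(pretrain_funcs, pretrain_epochs, serie):
        for layer, f in enumerate(pretrain_funcs):
            for epoch in xrange(pretrain_epochs):
                serie.append((layer, epoch), f())

    fake_funcs = [lambda: 0.5, lambda: 0.25]   # hypothetical per-layer costs
    do_pretrain(fake_funcs, pretrain_epochs=3, serie=DummySeries())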
--- a/deep/convolutional_dae/scdae.py	Thu Mar 25 12:20:27 2010 -0400
+++ b/deep/convolutional_dae/scdae.py	Fri Mar 26 17:24:17 2010 -0400
@@ -7,6 +7,7 @@
 import theano.tensor as T
 
 from itertools import izip
+from ift6266.utils.seriestables import *
 
 class cdae(LayerStack):
     def __init__(self, filter_size, num_filt, num_in, subsampling, corruption,
@@ -68,6 +69,9 @@
 
     n = scdae_net((1,)+img_size, batch_size, filter_sizes, num_filters, subs,
                   noise, mlp_sizes, out_size, dtype, batch_size)
+
+    n.save('start.net')
+
     x = T.fmatrix('x')
     y = T.ivector('y')
@@ -110,12 +114,12 @@
 
     trainf = select_f2(trainf_opt, trainf_reg, batch_size)
     evalf = select_f2(evalf_opt, evalf_reg, batch_size)
-    return pretrain_funcs, trainf, evalf
+    return pretrain_funcs, trainf, evalf, n
 
-def do_pretrain(pretrain_funcs, pretrain_epochs):
-    for f in pretrain_funcs:
-        for i in xrange(pretrain_epochs):
-            f()
+def do_pretrain(pretrain_funcs, pretrain_epochs, serie):
+    for layer, f in enumerate(pretrain_funcs):
+        for epoch in xrange(pretrain_epochs):
+            serie.append((layer, epoch), f())
 
 def massage_funcs(train_it, dset, batch_size, pretrain_funcs, trainf, evalf):
     def pretrain_f(f):
@@ -156,16 +160,52 @@
         for e in itf(*args, **kwargs):
             yield e
 
+def create_series():
+    import tables
+
+    series = {}
+    h5f = tables.openFile('series.h5', 'w')
+
+    series['recons_error'] = AccumulatorSeriesWrapper(
+                    base_series=ErrorSeries(error_name='reconstruction_error',
+                                            table_name='reconstruction_error',
+                                            hdf5_file=h5f,
+                                            index_names=('layer', 'epoch'),
+                                            title="Reconstruction error (mse)")
+                    reduce_every=100)
+
+    series['training_err'] = AccumulatorSeriesWrapper(
+                    base_series=ErrorSeries(error_name='training_error',
+                                            table_name='training_error'
+                                            hdf5_file=h5f,
+                                            index_names=('iter',),
+                                            titles='Training error (nll)')
+                    reduce_every=100)
+
+    series['valid_err'] = ErrorSeries(error_name='valid_error',
+                                      table_name='valid_error'
+                                      hdf5_file=h5f,
+                                      index_names=('iter',),
+                                      titles='Validation error (class)')
+
+    series['test_err'] = ErrorSeries(error_name='test_error',
+                                     table_name='test_error'
+                                     hdf5_file=h5f,
+                                     index_names=('iter',),
+                                     titles='Test error (class)')
+
 def run_exp(state, channel):
     from ift6266 import datasets
     from sgd_opt import sgd_opt
     import sys, time
 
-    channel.save()
-
     # params: bsize, pretrain_lr, train_lr, nfilts1, nfilts2, nftils3, nfilts4
     # pretrain_rounds
 
+    pylearn.version.record_versions(state, [theano,ift6266,pylearn])
+    # TODO: maybe record pynnet version?
+    channel.save()
+
     dset = dataset.nist_all()
 
     nfilts = []
@@ -182,7 +222,7 @@
     subs = [(2,2)]*len(nfilts)
     noise = [state.noise]*len(nfilts)
 
-    pretrain_funcs, trainf, evalf = build_funcs(
+    pretrain_funcs, trainf, evalf, net = build_funcs(
         img_size=(32, 32),
         batch_size=state.bsize,
         filter_sizes=fsizes,
@@ -198,11 +238,13 @@
     pretrain_fs, train, valid, test = massage_funcs(
         state.bsize, dset, pretrain_funcs, trainf, evalf)
 
-    do_pretrain(pretrain_fs, state.pretrain_rounds)
+    series = create_series()
+
+    do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error'])
 
     sgd_opt(train, valid, test, training_epochs=100000, patience=10000,
             patience_increase=2., improvement_threshold=0.995,
-            validation_frequency=2500)
+            validation_frequency=2500, series=series, net=net)
 
 if __name__ == '__main__':
     from ift6266 import datasets
@@ -212,7 +254,7 @@
     batch_size = 100
     dset = datasets.mnist()
 
-    pretrain_funcs, trainf, evalf = build_funcs(
+    pretrain_funcs, trainf, evalf, net = build_funcs(
         img_size = (28, 28),
         batch_size=batch_size, filter_sizes=[(5,5), (3,3)],
         num_filters=[4, 4], subs=[(2,2), (2,2)], noise=[0.2, 0.2],
@@ -227,7 +269,7 @@
     print "pretraining ...",
     sys.stdout.flush()
     start = time.time()
-    do_pretrain(pretrain_fs, 2500)
+    do_pretrain(pretrain_fs, 2500, DummySeries())
     end = time.time()
     print "done (in", end-start, "s)"
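As committed, create_series() has a few slips: commas are missing after several table_name= arguments and before reduce_every=, the last three series pass titles= where the first passes title=, and the function never returns series even though run_exp() assigns its result. A hedged sketch of one corrected entry, assuming the ErrorSeries and AccumulatorSeriesWrapper keyword arguments used in the diff:

    import tables
    from ift6266.utils.seriestables import ErrorSeries, AccumulatorSeriesWrapper

    def create_series():
        series = {}
        h5f = tables.openFile('series.h5', 'w')   # PyTables 2.x spelling

        # One stored row per 100 appended reconstruction costs.
        series['recons_error'] = AccumulatorSeriesWrapper(
            base_series=ErrorSeries(error_name='reconstruction_error',
                                    table_name='reconstruction_error',
                                    hdf5_file=h5f,
                                    index_names=('layer', 'epoch'),
                                    title="Reconstruction error (mse)"),
            reduce_every=100)

        # ... training_err, valid_err and test_err built the same way ...
        return series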
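Once an experiment has run, the logged values can be read back from series.h5 with PyTables. A sketch, assuming ErrorSeries creates one table per table_name under the file root, with columns named after the index_names and error_name arguments (neither detail is confirmed by this diff):

    import tables

    h5f = tables.openFile('series.h5', 'r')            # PyTables 2.x API
    errors = h5f.getNode('/', 'reconstruction_error')  # assumed location
    for row in errors:
        # Column names are assumed to mirror index_names and error_name.
        print row['layer'], row['epoch'], row['reconstruction_error']
    h5f.close()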