# HG changeset patch # User Arnaud Bergeron # Date 1269643767 14400 # Node ID 8108d271c30c57eefadc660f64435575aa524c6d # Parent 7d1fa2d7721c44bff365ed31c8b405a4c6ded32b Fix stuff (imports, ...) so that it can run under jobman properly. diff -r 7d1fa2d7721c -r 8108d271c30c deep/convolutional_dae/run_exp.py --- a/deep/convolutional_dae/run_exp.py Fri Mar 26 18:35:23 2010 -0400 +++ b/deep/convolutional_dae/run_exp.py Fri Mar 26 18:49:27 2010 -0400 @@ -47,7 +47,9 @@ train_lr=state.train_lr) pretrain_fs, train, valid, test = massage_funcs( - state.bsize, dset, pretrain_funcs, trainf, evalf) + repeat_itf(dset.train, state.bsize), + dset, state.bsize, + pretrain_funcs, trainf,evalf) series = create_series() diff -r 7d1fa2d7721c -r 8108d271c30c deep/convolutional_dae/scdae.py --- a/deep/convolutional_dae/scdae.py Fri Mar 26 18:35:23 2010 -0400 +++ b/deep/convolutional_dae/scdae.py Fri Mar 26 18:49:27 2010 -0400 @@ -174,25 +174,27 @@ title="Reconstruction error (mse)"), reduce_every=100) - series['training_err'] = AccumulatorSeriesWrapper( + series['train_error'] = AccumulatorSeriesWrapper( base_series=ErrorSeries(error_name='training_error', table_name='training_error', hdf5_file=h5f, index_names=('iter',), - titles='Training error (nll)'), + title='Training error (nll)'), reduce_every=100) - - series['valid_err'] = ErrorSeries(error_name='valid_error', - table_name='valid_error', - hdf5_file=h5f, - index_names=('iter',), - titles='Validation error (class)') - - series['test_err'] = ErrorSeries(error_name='test_error', - table_name='test_error', - hdf5_file=h5f, - index_names=('iter',), - titles='Test error (class)') + + series['valid_error'] = ErrorSeries(error_name='valid_error', + table_name='valid_error', + hdf5_file=h5f, + index_names=('iter',), + title='Validation error (class)') + + series['test_error'] = ErrorSeries(error_name='test_error', + table_name='test_error', + hdf5_file=h5f, + index_names=('iter',), + title='Test error (class)') + + return series def run_exp(state, channel): from ift6266 import datasets @@ -206,7 +208,7 @@ # TODO: maybe record pynnet version? channel.save() - dset = dataset.nist_all() + dset = dataset.nist_all(1000) nfilts = [] if state.nfilts1 != 0: