Mercurial > ift6266
comparison deep/convolutional_dae/scdae.py @ 292:8108d271c30c
Fix stuff (imports, ...) so that it can run under jobman properly.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Fri, 26 Mar 2010 18:49:27 -0400 |
parents | 518589bfee55 |
children | a222af1d0598 |
comparison
equal
deleted
inserted
replaced
291:7d1fa2d7721c | 292:8108d271c30c |
---|---|
172 hdf5_file=h5f, | 172 hdf5_file=h5f, |
173 index_names=('layer', 'epoch'), | 173 index_names=('layer', 'epoch'), |
174 title="Reconstruction error (mse)"), | 174 title="Reconstruction error (mse)"), |
175 reduce_every=100) | 175 reduce_every=100) |
176 | 176 |
177 series['training_err'] = AccumulatorSeriesWrapper( | 177 series['train_error'] = AccumulatorSeriesWrapper( |
178 base_series=ErrorSeries(error_name='training_error', | 178 base_series=ErrorSeries(error_name='training_error', |
179 table_name='training_error', | 179 table_name='training_error', |
180 hdf5_file=h5f, | 180 hdf5_file=h5f, |
181 index_names=('iter',), | 181 index_names=('iter',), |
182 titles='Training error (nll)'), | 182 title='Training error (nll)'), |
183 reduce_every=100) | 183 reduce_every=100) |
184 | 184 |
185 series['valid_err'] = ErrorSeries(error_name='valid_error', | 185 series['valid_error'] = ErrorSeries(error_name='valid_error', |
186 table_name='valid_error', | 186 table_name='valid_error', |
187 hdf5_file=h5f, | 187 hdf5_file=h5f, |
188 index_names=('iter',), | 188 index_names=('iter',), |
189 titles='Validation error (class)') | 189 title='Validation error (class)') |
190 | 190 |
191 series['test_err'] = ErrorSeries(error_name='test_error', | 191 series['test_error'] = ErrorSeries(error_name='test_error', |
192 table_name='test_error', | 192 table_name='test_error', |
193 hdf5_file=h5f, | 193 hdf5_file=h5f, |
194 index_names=('iter',), | 194 index_names=('iter',), |
195 titles='Test error (class)') | 195 title='Test error (class)') |
196 | |
197 return series | |
196 | 198 |
197 def run_exp(state, channel): | 199 def run_exp(state, channel): |
198 from ift6266 import datasets | 200 from ift6266 import datasets |
199 from sgd_opt import sgd_opt | 201 from sgd_opt import sgd_opt |
200 import sys, time | 202 import sys, time |
204 | 206 |
205 pylearn.version.record_versions(state, [theano,ift6266,pylearn]) | 207 pylearn.version.record_versions(state, [theano,ift6266,pylearn]) |
206 # TODO: maybe record pynnet version? | 208 # TODO: maybe record pynnet version? |
207 channel.save() | 209 channel.save() |
208 | 210 |
209 dset = dataset.nist_all() | 211 dset = dataset.nist_all(1000) |
210 | 212 |
211 nfilts = [] | 213 nfilts = [] |
212 if state.nfilts1 != 0: | 214 if state.nfilts1 != 0: |
213 nfilts.append(state.nfilts1) | 215 nfilts.append(state.nfilts1) |
214 if state.nfilts2 != 0: | 216 if state.nfilts2 != 0: |