comparison deep/convolutional_dae/scdae.py @ 292:8108d271c30c

Fix stuff (imports, ...) so that it can run under jobman properly.
author Arnaud Bergeron <abergeron@gmail.com>
date Fri, 26 Mar 2010 18:49:27 -0400
parents 518589bfee55
children a222af1d0598
comparison
equal deleted inserted replaced
291:7d1fa2d7721c 292:8108d271c30c
172 hdf5_file=h5f, 172 hdf5_file=h5f,
173 index_names=('layer', 'epoch'), 173 index_names=('layer', 'epoch'),
174 title="Reconstruction error (mse)"), 174 title="Reconstruction error (mse)"),
175 reduce_every=100) 175 reduce_every=100)
176 176
177 series['training_err'] = AccumulatorSeriesWrapper( 177 series['train_error'] = AccumulatorSeriesWrapper(
178 base_series=ErrorSeries(error_name='training_error', 178 base_series=ErrorSeries(error_name='training_error',
179 table_name='training_error', 179 table_name='training_error',
180 hdf5_file=h5f, 180 hdf5_file=h5f,
181 index_names=('iter',), 181 index_names=('iter',),
182 titles='Training error (nll)'), 182 title='Training error (nll)'),
183 reduce_every=100) 183 reduce_every=100)
184 184
185 series['valid_err'] = ErrorSeries(error_name='valid_error', 185 series['valid_error'] = ErrorSeries(error_name='valid_error',
186 table_name='valid_error', 186 table_name='valid_error',
187 hdf5_file=h5f, 187 hdf5_file=h5f,
188 index_names=('iter',), 188 index_names=('iter',),
189 titles='Validation error (class)') 189 title='Validation error (class)')
190 190
191 series['test_err'] = ErrorSeries(error_name='test_error', 191 series['test_error'] = ErrorSeries(error_name='test_error',
192 table_name='test_error', 192 table_name='test_error',
193 hdf5_file=h5f, 193 hdf5_file=h5f,
194 index_names=('iter',), 194 index_names=('iter',),
195 titles='Test error (class)') 195 title='Test error (class)')
196
197 return series
196 198
197 def run_exp(state, channel): 199 def run_exp(state, channel):
198 from ift6266 import datasets 200 from ift6266 import datasets
199 from sgd_opt import sgd_opt 201 from sgd_opt import sgd_opt
200 import sys, time 202 import sys, time
204 206
205 pylearn.version.record_versions(state, [theano,ift6266,pylearn]) 207 pylearn.version.record_versions(state, [theano,ift6266,pylearn])
206 # TODO: maybe record pynnet version? 208 # TODO: maybe record pynnet version?
207 channel.save() 209 channel.save()
208 210
209 dset = dataset.nist_all() 211 dset = dataset.nist_all(1000)
210 212
211 nfilts = [] 213 nfilts = []
212 if state.nfilts1 != 0: 214 if state.nfilts1 != 0:
213 nfilts.append(state.nfilts1) 215 nfilts.append(state.nfilts1)
214 if state.nfilts2 != 0: 216 if state.nfilts2 != 0: