diff deep/convolutional_dae/scdae.py @ 292:8108d271c30c

Fix stuff (imports, ...) so that it can run under jobman properly.
author Arnaud Bergeron <abergeron@gmail.com>
date Fri, 26 Mar 2010 18:49:27 -0400
parents 518589bfee55
children a222af1d0598
line wrap: on
line diff
--- a/deep/convolutional_dae/scdae.py	Fri Mar 26 18:35:23 2010 -0400
+++ b/deep/convolutional_dae/scdae.py	Fri Mar 26 18:49:27 2010 -0400
@@ -174,25 +174,27 @@
                                 title="Reconstruction error (mse)"),
         reduce_every=100)
         
-    series['training_err'] = AccumulatorSeriesWrapper(
+    series['train_error'] = AccumulatorSeriesWrapper(
         base_series=ErrorSeries(error_name='training_error',
                                 table_name='training_error',
                                 hdf5_file=h5f,
                                 index_names=('iter',),
-                                titles='Training error (nll)'),
+                                title='Training error (nll)'),
         reduce_every=100)
-
-    series['valid_err'] = ErrorSeries(error_name='valid_error',
-                                         table_name='valid_error',
-                                         hdf5_file=h5f,
-                                         index_names=('iter',),
-                                         titles='Validation error (class)')
-
-    series['test_err'] = ErrorSeries(error_name='test_error',
-                                         table_name='test_error',
-                                         hdf5_file=h5f,
-                                         index_names=('iter',),
-                                         titles='Test error (class)')
+    
+    series['valid_error'] = ErrorSeries(error_name='valid_error',
+                                        table_name='valid_error',
+                                        hdf5_file=h5f,
+                                        index_names=('iter',),
+                                        title='Validation error (class)')
+    
+    series['test_error'] = ErrorSeries(error_name='test_error',
+                                       table_name='test_error',
+                                       hdf5_file=h5f,
+                                       index_names=('iter',),
+                                       title='Test error (class)')
+    
+    return series
 
 def run_exp(state, channel):
     from ift6266 import datasets
@@ -206,7 +208,7 @@
     # TODO: maybe record pynnet version?
     channel.save()
 
-    dset = dataset.nist_all()
+    dset = dataset.nist_all(1000)
 
     nfilts = []
     if state.nfilts1 != 0: