changeset 292:8108d271c30c

Fix stuff (imports, ...) so that it can run under jobman properly.
author Arnaud Bergeron <abergeron@gmail.com>
date Fri, 26 Mar 2010 18:49:27 -0400
parents 7d1fa2d7721c
children d89820070ea0
files deep/convolutional_dae/run_exp.py deep/convolutional_dae/scdae.py
diffstat 2 files changed, 20 insertions(+), 16 deletions(-) [+]
line wrap: on
line diff
--- a/deep/convolutional_dae/run_exp.py	Fri Mar 26 18:35:23 2010 -0400
+++ b/deep/convolutional_dae/run_exp.py	Fri Mar 26 18:49:27 2010 -0400
@@ -47,7 +47,9 @@
         train_lr=state.train_lr)
 
     pretrain_fs, train, valid, test = massage_funcs(
-        state.bsize, dset, pretrain_funcs, trainf, evalf)
+        repeat_itf(dset.train, state.bsize), 
+        dset, state.bsize, 
+        pretrain_funcs, trainf,evalf)
 
     series = create_series()
 
--- a/deep/convolutional_dae/scdae.py	Fri Mar 26 18:35:23 2010 -0400
+++ b/deep/convolutional_dae/scdae.py	Fri Mar 26 18:49:27 2010 -0400
@@ -174,25 +174,27 @@
                                 title="Reconstruction error (mse)"),
         reduce_every=100)
         
-    series['training_err'] = AccumulatorSeriesWrapper(
+    series['train_error'] = AccumulatorSeriesWrapper(
         base_series=ErrorSeries(error_name='training_error',
                                 table_name='training_error',
                                 hdf5_file=h5f,
                                 index_names=('iter',),
-                                titles='Training error (nll)'),
+                                title='Training error (nll)'),
         reduce_every=100)
-
-    series['valid_err'] = ErrorSeries(error_name='valid_error',
-                                         table_name='valid_error',
-                                         hdf5_file=h5f,
-                                         index_names=('iter',),
-                                         titles='Validation error (class)')
-
-    series['test_err'] = ErrorSeries(error_name='test_error',
-                                         table_name='test_error',
-                                         hdf5_file=h5f,
-                                         index_names=('iter',),
-                                         titles='Test error (class)')
+    
+    series['valid_error'] = ErrorSeries(error_name='valid_error',
+                                        table_name='valid_error',
+                                        hdf5_file=h5f,
+                                        index_names=('iter',),
+                                        title='Validation error (class)')
+    
+    series['test_error'] = ErrorSeries(error_name='test_error',
+                                       table_name='test_error',
+                                       hdf5_file=h5f,
+                                       index_names=('iter',),
+                                       title='Test error (class)')
+    
+    return series
 
 def run_exp(state, channel):
     from ift6266 import datasets
@@ -206,7 +208,7 @@
     # TODO: maybe record pynnet version?
     channel.save()
 
-    dset = dataset.nist_all()
+    dset = dataset.nist_all(1000)
 
     nfilts = []
     if state.nfilts1 != 0: