changeset 240:f213a0fb2b08

Corrigé bugs dans stacked_dae/v2 en rapport avec les modifs d'hier
author fsavard
date Tue, 16 Mar 2010 10:52:04 -0400
parents 42005ec87747
children c24020aa38ac 8a00764ea8a4
files deep/stacked_dae/v2/sgd_optimization.py
diffstat 1 files changed, 8 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/deep/stacked_dae/v2/sgd_optimization.py	Mon Mar 15 18:30:21 2010 -0400
+++ b/deep/stacked_dae/v2/sgd_optimization.py	Tue Mar 16 10:52:04 2010 -0400
@@ -29,7 +29,7 @@
     for i,it in enumerate(iter):
         if i >= max:
             break
-        yield i
+        yield it
 
 class SdaSgdOptimizer:
     def __init__(self, dataset, hyperparameters, n_ins, n_outs,
@@ -98,8 +98,8 @@
                     self.series["reconstruction_error"].append((epoch, batch_index), c)
                     batch_index+=1
 
-                    if batch_index % 10000 == 0:
-                        print "10000 batches"
+                    if batch_index % 100 == 0:
+                        print "100 batches"
 
                     # useful when doing tests
                     if self.max_minibatches and batch_index >= self.max_minibatches:
@@ -150,6 +150,8 @@
                                       # minibatche before checking the network 
                                       # on the validation set; in this case we 
                                       # check every epoch 
+        if self.max_minibatches and validation_frequency > self.max_minibatches:
+            validation_frequency = self.max_minibatches / 2
 
         best_params          = None
         best_validation_loss = float('inf')
@@ -183,7 +185,7 @@
                         append((epoch, minibatch_index), this_validation_loss*100.)
 
                     print('epoch %i, minibatch %i/%i, validation error %f %%' % \
-                           (epoch, minibatch_index+1, self.n_train_batches, \
+                           (epoch, minibatch_index+1, self.mb_per_epoch, \
                             this_validation_loss*100.))
 
 
@@ -211,13 +213,13 @@
 
                         print(('     epoch %i, minibatch %i/%i, test error of best '
                               'model %f %%') % 
-                                     (epoch, minibatch_index+1, self.n_train_batches,
+                                     (epoch, minibatch_index+1, self.mb_per_epoch,
                                       test_score*100.))
 
                     sys.stdout.flush()
 
                 # useful when doing tests
-                if self.max_minibatches and batch_index >= self.max_minibatches:
+                if self.max_minibatches and minibatch_index >= self.max_minibatches:
                     break
 
             self.series['params'].append((epoch,), self.classifier.all_params)