diff deep/stacked_dae/sgd_optimization.py @ 284:8a3af19ae272

Removed the mechanism for limiting the number of examples used (replaced by a parameter in the call to the dataset code), and added an option to save the weights at the end of training
author fsavard
date Wed, 24 Mar 2010 15:13:48 -0400
parents 7b4507295eba
children
--- a/deep/stacked_dae/sgd_optimization.py	Wed Mar 24 14:36:55 2010 -0400
+++ b/deep/stacked_dae/sgd_optimization.py	Wed Mar 24 15:13:48 2010 -0400
@@ -3,6 +3,8 @@
 
 # Generic SdA optimization loop, adapted from the deeplearning.net tutorial
 
+from __future__ import with_statement
+
 import numpy 
 import theano
 import time
@@ -25,22 +27,16 @@
         'params' : DummySeries()
         }
 
-def itermax(iter, max):
-    for i,it in enumerate(iter):
-        if i >= max:
-            break
-        yield it
-
 class SdaSgdOptimizer:
     def __init__(self, dataset, hyperparameters, n_ins, n_outs,
-                    examples_per_epoch, series=default_series, max_minibatches=None):
+                    examples_per_epoch, series=default_series, 
+                    save_params=False):
         self.dataset = dataset
         self.hp = hyperparameters
         self.n_ins = n_ins
         self.n_outs = n_outs
-   
-        self.max_minibatches = max_minibatches
-        print "SdaSgdOptimizer, max_minibatches =", max_minibatches
+
+        self.save_params = save_params
 
         self.ex_per_epoch = examples_per_epoch
         self.mb_per_epoch = examples_per_epoch / self.hp.minibatch_size
@@ -101,10 +97,6 @@
                     #if batch_index % 100 == 0:
                     #    print "100 batches"
 
-                    # useful when doing tests
-                    if self.max_minibatches and batch_index >= self.max_minibatches:
-                        break
-                        
                 print 'Pre-training layer %i, epoch %d, cost '%(i,epoch),c
                 sys.stdout.flush()
 
@@ -150,8 +142,6 @@
                                       # minibatche before checking the network 
                                       # on the validation set; in this case we 
                                       # check every epoch 
-        if self.max_minibatches and validation_frequency > self.max_minibatches:
-            validation_frequency = self.max_minibatches / 2
 
         best_params          = None
         best_validation_loss = float('inf')
@@ -176,8 +166,6 @@
                 if (total_mb_index+1) % validation_frequency == 0: 
                     
                     iter = dataset.valid(minibatch_size)
-                    if self.max_minibatches:
-                        iter = itermax(iter, self.max_minibatches)
                     validation_losses = [validate_model(x,y) for x,y in iter]
                     this_validation_loss = numpy.mean(validation_losses)
 
@@ -203,8 +191,6 @@
 
                         # test it on the test set
                         iter = dataset.test(minibatch_size)
-                        if self.max_minibatches:
-                            iter = itermax(iter, self.max_minibatches)
                         test_losses = [test_model(x,y) for x,y in iter]
                         test_score = numpy.mean(test_losses)
 
@@ -218,10 +204,6 @@
 
                     sys.stdout.flush()
 
-                # useful when doing tests
-                if self.max_minibatches and minibatch_index >= self.max_minibatches:
-                    break
-
             self.series['params'].append((epoch,), self.classifier.all_params)
 
             if patience <= total_mb_index:
@@ -234,6 +216,9 @@
                     'test_score':test_score,
                     'num_finetuning_epochs':epoch})
 
+        if self.save_params:
+            save_params(self.classifier.all_params, "weights.dat")
+
         print(('Optimization complete with best validation score of %f %%,'
                'with test performance %f %%') %  
                      (best_validation_loss * 100., test_score*100.))
@@ -241,3 +226,11 @@
 
 
 
+def save_params(all_params, filename):
+    import pickle
+    with open(filename, 'wb') as f:
+        values = [p.value for p in all_params]
+
+        # -1 for HIGHEST_PROTOCOL
+        pickle.dump(values, f, -1)
+
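Note: the new module-level save_params helper pickles the raw .value arrays of the classifier's shared parameters, using protocol -1 (pickle.HIGHEST_PROTOCOL). The changeset does not include a loader; a minimal counterpart might look like the sketch below. It assumes the same old-Theano .value accessor used in the diff above, and that all_params is passed in the same order as at save time; load_params itself is illustrative, not part of this commit.

    import pickle

    def load_params(all_params, filename):
        # Restore values pickled by save_params back into the model's
        # shared variables. Assumes all_params matches the order used
        # when the file was written.
        with open(filename, 'rb') as f:
            values = pickle.load(f)
        for p, v in zip(all_params, values):
            p.value = v

As for the removed code: the itermax helper was a hand-rolled equivalent of itertools.islice. Per the commit message, capping the number of examples is now expected to happen through a parameter in the call to the dataset code, so the optimizer no longer needs max_minibatches at all.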