changeset 284:8a3af19ae272

Enlevé mécanique pour limiter le nombre d'exemples utilisés (remplacé par paramètre dans l'appel au code de dataset), et ajouté option pour sauvegarde des poids à la fin de l'entraînement
author fsavard
date Wed, 24 Mar 2010 15:13:48 -0400
parents 206374eed2fb
children 694e75413413
files deep/stacked_dae/config.py.example deep/stacked_dae/nist_sda.py deep/stacked_dae/sgd_optimization.py
diffstat 3 files changed, 32 insertions(+), 28 deletions(-) [+]
line wrap: on
line diff
--- a/deep/stacked_dae/config.py.example	Wed Mar 24 14:36:55 2010 -0400
+++ b/deep/stacked_dae/config.py.example	Wed Mar 24 15:13:48 2010 -0400
@@ -59,6 +59,9 @@
 # Set this PRIOR to inserting your test jobs in the DB.
 TEST_CONFIG = False
 
+# save params at training end
+SAVE_PARAMS = False
+
 NIST_ALL_LOCATION = '/data/lisa/data/nist/by_class/all'
 NIST_ALL_TRAIN_SIZE = 649081
 # valid et test =82587 82587 
--- a/deep/stacked_dae/nist_sda.py	Wed Mar 24 14:36:55 2010 -0400
+++ b/deep/stacked_dae/nist_sda.py	Wed Mar 24 15:13:48 2010 -0400
@@ -56,22 +56,30 @@
     n_outs = 62 # 10 digits, 26*2 (lower, capitals)
      
     examples_per_epoch = NIST_ALL_TRAIN_SIZE
+    if rtt:
+        examples_per_epoch = rtt
 
     series = create_series(state.num_hidden_layers)
 
     print "Creating optimizer with state, ", state
 
-    optimizer = SdaSgdOptimizer(dataset=datasets.nist_all(), 
+    dataset = None
+    if rtt:
+        dataset = datasets.nist_all(maxsize=rtt)
+    else:
+        dataset = datasets.nist_all()
+
+    optimizer = SdaSgdOptimizer(dataset=dataset, 
                                     hyperparameters=state, \
                                     n_ins=n_ins, n_outs=n_outs,\
                                     examples_per_epoch=examples_per_epoch, \
                                     series=series,
-                                    max_minibatches=rtt)
+                                    save_params=SAVE_PARAMS)
 
-    optimizer.pretrain(datasets.nist_all())
+    optimizer.pretrain(dataset)
     channel.save()
 
-    optimizer.finetune(datasets.nist_all())
+    optimizer.finetune(dataset)
     channel.save()
 
     return channel.COMPLETE
--- a/deep/stacked_dae/sgd_optimization.py	Wed Mar 24 14:36:55 2010 -0400
+++ b/deep/stacked_dae/sgd_optimization.py	Wed Mar 24 15:13:48 2010 -0400
@@ -3,6 +3,8 @@
 
 # Generic SdA optimization loop, adapted from the deeplearning.net tutorial
 
+from __future__ import with_statement
+
 import numpy 
 import theano
 import time
@@ -25,22 +27,16 @@
         'params' : DummySeries()
         }
 
-def itermax(iter, max):
-    for i,it in enumerate(iter):
-        if i >= max:
-            break
-        yield it
-
 class SdaSgdOptimizer:
     def __init__(self, dataset, hyperparameters, n_ins, n_outs,
-                    examples_per_epoch, series=default_series, max_minibatches=None):
+                    examples_per_epoch, series=default_series, 
+                    save_params=False):
         self.dataset = dataset
         self.hp = hyperparameters
         self.n_ins = n_ins
         self.n_outs = n_outs
-   
-        self.max_minibatches = max_minibatches
-        print "SdaSgdOptimizer, max_minibatches =", max_minibatches
+
+        self.save_params = save_params
 
         self.ex_per_epoch = examples_per_epoch
         self.mb_per_epoch = examples_per_epoch / self.hp.minibatch_size
@@ -101,10 +97,6 @@
                     #if batch_index % 100 == 0:
                     #    print "100 batches"
 
-                    # useful when doing tests
-                    if self.max_minibatches and batch_index >= self.max_minibatches:
-                        break
-                        
                 print 'Pre-training layer %i, epoch %d, cost '%(i,epoch),c
                 sys.stdout.flush()
 
@@ -150,8 +142,6 @@
                                       # minibatche before checking the network 
                                       # on the validation set; in this case we 
                                       # check every epoch 
-        if self.max_minibatches and validation_frequency > self.max_minibatches:
-            validation_frequency = self.max_minibatches / 2
 
         best_params          = None
         best_validation_loss = float('inf')
@@ -176,8 +166,6 @@
                 if (total_mb_index+1) % validation_frequency == 0: 
                     
                     iter = dataset.valid(minibatch_size)
-                    if self.max_minibatches:
-                        iter = itermax(iter, self.max_minibatches)
                     validation_losses = [validate_model(x,y) for x,y in iter]
                     this_validation_loss = numpy.mean(validation_losses)
 
@@ -203,8 +191,6 @@
 
                         # test it on the test set
                         iter = dataset.test(minibatch_size)
-                        if self.max_minibatches:
-                            iter = itermax(iter, self.max_minibatches)
                         test_losses = [test_model(x,y) for x,y in iter]
                         test_score = numpy.mean(test_losses)
 
@@ -218,10 +204,6 @@
 
                     sys.stdout.flush()
 
-                # useful when doing tests
-                if self.max_minibatches and minibatch_index >= self.max_minibatches:
-                    break
-
             self.series['params'].append((epoch,), self.classifier.all_params)
 
             if patience <= total_mb_index:
@@ -234,6 +216,9 @@
                     'test_score':test_score,
                     'num_finetuning_epochs':epoch})
 
+        if self.save_params:
+            save_params(self.classifier.all_params, "weights.dat")
+
         print(('Optimization complete with best validation score of %f %%,'
                'with test performance %f %%') %  
                      (best_validation_loss * 100., test_score*100.))
@@ -241,3 +226,11 @@
 
 
 
+def save_params(all_params, filename):
+    import pickle
+    with open(filename, 'wb') as f:
+        values = [p.value for p in all_params]
+
+        # -1 for HIGHEST_PROTOCOL
+        pickle.dump(values, f, -1)
+