diff deep/stacked_dae/nist_sda.py @ 284:8a3af19ae272

Enlevé mécanique pour limiter le nombre d'exemples utilisés (remplacé par paramètre dans l'appel au code de dataset), et ajouté option pour sauvegarde des poids à la fin de l'entraînement
author fsavard
date Wed, 24 Mar 2010 15:13:48 -0400
parents 206374eed2fb
children
line wrap: on
line diff
--- a/deep/stacked_dae/nist_sda.py	Wed Mar 24 14:36:55 2010 -0400
+++ b/deep/stacked_dae/nist_sda.py	Wed Mar 24 15:13:48 2010 -0400
@@ -56,22 +56,30 @@
     n_outs = 62 # 10 digits, 26*2 (lower, capitals)
      
     examples_per_epoch = NIST_ALL_TRAIN_SIZE
+    if rtt:
+        examples_per_epoch = rtt
 
     series = create_series(state.num_hidden_layers)
 
     print "Creating optimizer with state, ", state
 
-    optimizer = SdaSgdOptimizer(dataset=datasets.nist_all(), 
+    dataset = None
+    if rtt:
+        dataset = datasets.nist_all(maxsize=rtt)
+    else:
+        dataset = datasets.nist_all()
+
+    optimizer = SdaSgdOptimizer(dataset=dataset, 
                                     hyperparameters=state, \
                                     n_ins=n_ins, n_outs=n_outs,\
                                     examples_per_epoch=examples_per_epoch, \
                                     series=series,
-                                    max_minibatches=rtt)
+                                    save_params=SAVE_PARAMS)
 
-    optimizer.pretrain(datasets.nist_all())
+    optimizer.pretrain(dataset)
     channel.save()
 
-    optimizer.finetune(datasets.nist_all())
+    optimizer.finetune(dataset)
     channel.save()
 
     return channel.COMPLETE