changeset 317:067e747fd9c0

Ajout de noms differents pour les series produites pour differents choix de pretrain
author SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca>
date Thu, 01 Apr 2010 20:09:14 -0400
parents 60e82846a10d
children 8de3bef71458
files deep/stacked_dae/v_sylvain/nist_sda_retrieve.py
diffstat 1 files changed, 15 insertions(+), 9 deletions(-) [+]
line wrap: on
line diff
--- a/deep/stacked_dae/v_sylvain/nist_sda_retrieve.py	Thu Apr 01 17:47:53 2010 -0400
+++ b/deep/stacked_dae/v_sylvain/nist_sda_retrieve.py	Thu Apr 01 20:09:14 2010 -0400
@@ -55,8 +55,20 @@
     n_outs = 62 # 10 digits, 26*2 (lower, capitals)
      
     examples_per_epoch = NIST_ALL_TRAIN_SIZE
+    #To be sure variables will not be only in the if statement
+    PATH = ''
+    nom_reptrain = ''
+    nom_serie = ""
+    if state['pretrain_choice'] == 0:
+        PATH=PATH_NIST
+        nom_pretrain='NIST'
+        nom_serie="series_NIST.h5"
+    elif state['pretrain_choice'] == 1:
+        PATH=PATH_P07
+        nom_pretrain='P07'
+        nom_serie="series_P07.h5"
 
-    series = create_series(state.num_hidden_layers)
+    series = create_series(state.num_hidden_layers,nom_serie)
 
     print "Creating optimizer with state, ", state
 
@@ -83,12 +95,6 @@
 ##             "or reduce the number of pretraining epoch to run the code (better idea).\n")
 ##        print('\n\tpretraining with P07')
 ##        optimizer.pretrain(datasets.nist_P07(min_file=0,max_file=nb_file)) 
-    if state['pretrain_choice'] == 0:
-        PATH=PATH_NIST
-        nom_pretrain='NIST'
-    elif state['pretrain_choice'] == 1:
-        PATH=PATH_P07
-        nom_pretrain='P07'
     
     print ('Retrieve pre-train done earlier ( '+nom_pretrain+' )')
     
@@ -168,7 +174,7 @@
 
 # These Series objects are used to save various statistics
 # during the training.
-def create_series(num_hidden_layers):
+def create_series(num_hidden_layers, nom_serie):
 
     # Replace series we don't want to save with DummySeries, e.g.
     # series['training_error'] = DummySeries()
@@ -177,7 +183,7 @@
 
     basedir = os.getcwd()
 
-    h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w")
+    h5f = tables.openFile(os.path.join(basedir, nom_serie), "w")
 
     # reconstruction
     reconstruction_base = \