diff deep/stacked_dae/v_sylvain/nist_byclass_error.py @ 446:9fb893d039c6

Ajout de fonctionnalite pour calculer les erreurs de validation avec PNIST et P07
author SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca>
date Tue, 04 May 2010 19:59:14 -0400
parents eb42bed0c13b
children
line wrap: on
line diff
--- a/deep/stacked_dae/v_sylvain/nist_byclass_error.py	Tue May 04 11:17:27 2010 -0400
+++ b/deep/stacked_dae/v_sylvain/nist_byclass_error.py	Tue May 04 19:59:14 2010 -0400
@@ -57,6 +57,7 @@
     examples_per_epoch = NIST_ALL_TRAIN_SIZE
 
     PATH = ''
+    NIST_BY_CLASS=0
 
 
 
@@ -75,64 +76,100 @@
     if os.path.exists(PATH+'params_finetune_NIST.txt'):
         print ('\n finetune = NIST ')
         optimizer.reload_parameters(PATH+'params_finetune_NIST.txt')
-        print "NIST DIGITS"
-        optimizer.training_error(datasets.nist_digits(),part=2)
-        print "NIST LOWER CASE"
-        optimizer.training_error(datasets.nist_lower(),part=2)
-        print "NIST UPPER CASE"
-        optimizer.training_error(datasets.nist_upper(),part=2)
+        if NIST_BY_CLASS == 1:
+            print "NIST DIGITS"
+            optimizer.training_error(datasets.nist_digits(),part=2)
+            print "NIST LOWER CASE"
+            optimizer.training_error(datasets.nist_lower(),part=2)
+            print "NIST UPPER CASE"
+            optimizer.training_error(datasets.nist_upper(),part=2)
+        else:
+            print "P07 valid"
+            optimizer.training_error(datasets.nist_P07(),part=1)
+            print "PNIST valid"
+            optimizer.training_error(datasets.PNIST07(),part=1)
         
     
     if os.path.exists(PATH+'params_finetune_P07.txt'):
         print ('\n finetune = P07 ')
         optimizer.reload_parameters(PATH+'params_finetune_P07.txt')
-        print "NIST DIGITS"
-        optimizer.training_error(datasets.nist_digits(),part=2)
-        print "NIST LOWER CASE"
-        optimizer.training_error(datasets.nist_lower(),part=2)
-        print "NIST UPPER CASE"
-        optimizer.training_error(datasets.nist_upper(),part=2)
+        if NIST_BY_CLASS == 1:
+            print "NIST DIGITS"
+            optimizer.training_error(datasets.nist_digits(),part=2)
+            print "NIST LOWER CASE"
+            optimizer.training_error(datasets.nist_lower(),part=2)
+            print "NIST UPPER CASE"
+            optimizer.training_error(datasets.nist_upper(),part=2)
+        else:
+            print "P07 valid"
+            optimizer.training_error(datasets.nist_P07(),part=1)
+            print "PNIST valid"
+            optimizer.training_error(datasets.PNIST07(),part=1)
 
     
     if os.path.exists(PATH+'params_finetune_NIST_then_P07.txt'):
         print ('\n finetune = NIST then P07')
         optimizer.reload_parameters(PATH+'params_finetune_NIST_then_P07.txt')
-        print "NIST DIGITS"
-        optimizer.training_error(datasets.nist_digits(),part=2)
-        print "NIST LOWER CASE"
-        optimizer.training_error(datasets.nist_lower(),part=2)
-        print "NIST UPPER CASE"
-        optimizer.training_error(datasets.nist_upper(),part=2)
+        if NIST_BY_CLASS == 1:
+            print "NIST DIGITS"
+            optimizer.training_error(datasets.nist_digits(),part=2)
+            print "NIST LOWER CASE"
+            optimizer.training_error(datasets.nist_lower(),part=2)
+            print "NIST UPPER CASE"
+            optimizer.training_error(datasets.nist_upper(),part=2)
+        else:
+            print "P07 valid"
+            optimizer.training_error(datasets.nist_P07(),part=1)
+            print "PNIST valid"
+            optimizer.training_error(datasets.PNIST07(),part=1)
     
     if os.path.exists(PATH+'params_finetune_P07_then_NIST.txt'):
         print ('\n finetune = P07 then NIST')
         optimizer.reload_parameters(PATH+'params_finetune_P07_then_NIST.txt')
-        print "NIST DIGITS"
-        optimizer.training_error(datasets.nist_digits(),part=2)
-        print "NIST LOWER CASE"
-        optimizer.training_error(datasets.nist_lower(),part=2)
-        print "NIST UPPER CASE"
-        optimizer.training_error(datasets.nist_upper(),part=2)
+        if NIST_BY_CLASS == 1:
+            print "NIST DIGITS"
+            optimizer.training_error(datasets.nist_digits(),part=2)
+            print "NIST LOWER CASE"
+            optimizer.training_error(datasets.nist_lower(),part=2)
+            print "NIST UPPER CASE"
+            optimizer.training_error(datasets.nist_upper(),part=2)
+        else:
+            print "P07 valid"
+            optimizer.training_error(datasets.nist_P07(),part=1)
+            print "PNIST valid"
+            optimizer.training_error(datasets.PNIST07(),part=1)
     
     if os.path.exists(PATH+'params_finetune_PNIST07.txt'):
         print ('\n finetune = PNIST07')
         optimizer.reload_parameters(PATH+'params_finetune_PNIST07.txt')
-        print "NIST DIGITS"
-        optimizer.training_error(datasets.nist_digits(),part=2)
-        print "NIST LOWER CASE"
-        optimizer.training_error(datasets.nist_lower(),part=2)
-        print "NIST UPPER CASE"
-        optimizer.training_error(datasets.nist_upper(),part=2)
+        if NIST_BY_CLASS == 1:
+            print "NIST DIGITS"
+            optimizer.training_error(datasets.nist_digits(),part=2)
+            print "NIST LOWER CASE"
+            optimizer.training_error(datasets.nist_lower(),part=2)
+            print "NIST UPPER CASE"
+            optimizer.training_error(datasets.nist_upper(),part=2)
+        else:
+            print "P07 valid"
+            optimizer.training_error(datasets.nist_P07(),part=1)
+            print "PNIST valid"
+            optimizer.training_error(datasets.PNIST07(),part=1)
         
     if os.path.exists(PATH+'params_finetune_PNIST07_then_NIST.txt'):
         print ('\n finetune = PNIST07 then NIST')
         optimizer.reload_parameters(PATH+'params_finetune_PNIST07_then_NIST.txt')
-        print "NIST DIGITS"
-        optimizer.training_error(datasets.nist_digits(),part=2)
-        print "NIST LOWER CASE"
-        optimizer.training_error(datasets.nist_lower(),part=2)
-        print "NIST UPPER CASE"
-        optimizer.training_error(datasets.nist_upper(),part=2)
+        if NIST_BY_CLASS == 1:
+            print "NIST DIGITS"
+            optimizer.training_error(datasets.nist_digits(),part=2)
+            print "NIST LOWER CASE"
+            optimizer.training_error(datasets.nist_lower(),part=2)
+            print "NIST UPPER CASE"
+            optimizer.training_error(datasets.nist_upper(),part=2)
+        else:
+            print "P07 valid"
+            optimizer.training_error(datasets.nist_P07(),part=1)
+            print "PNIST valid"
+            optimizer.training_error(datasets.PNIST07(),part=1)
     
     channel.save()