diff deep/amt/amt.py @ 402:83413ac10913

Added more stats printing. You no longer need to pass a parameter saying which dataset you are testing; it is now detected automatically.
author humel
date Wed, 28 Apr 2010 11:28:28 -0400
parents 86d5e583e278
children a11692910312
line wrap: on
line diff
--- a/deep/amt/amt.py	Wed Apr 28 01:25:52 2010 -0400
+++ b/deep/amt/amt.py	Wed Apr 28 11:28:28 2010 -0400
@@ -1,14 +1,17 @@
 # Script usage : python amt.py filname.cvs
 
-import csv,numpy,re
+import csv,numpy,re,decimal
 
-PATH = 'nist/'
-CVSFILE = None
+DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/',
+                 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/',
+                 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }
 
-answer_labels = [ 'Answer.C'+str(i+1) for i in range(10) ]
-img_url = 'Input.image_url'
+CVSFILE = None
+#PATH = None
+answer_labels = [ 'answer.c'+str(i+1) for i in range(10) ]
+img_url = 'input.image_url'
 turks_per_batch = 3
-
+image_per_batch = 10
 
 def setup_association():
     answer_assoc = {}
@@ -37,18 +40,25 @@
     error_means = numpy.zeros((total_uniq_entries,))
     error_variances = numpy.zeros((total_uniq_entries,))
 
-
+    PATH = DATASET_PATH[find_dataset(entries[0])]
     for i in range(total_uniq_entries):
         for t in range(turks_per_batch):
-            errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t])
+            errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],PATH)
         error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
         error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
 
-    print 'Average test error: ' + str(error_means.mean()*10) +'%'
-    print 'Error variance    : ' + str(error_variances.mean()*10) +'%' 
+    decimal.getcontext().prec = 3
+    print 'Total entries      : ' + str(len(entries))
+    print 'Turks per batch    : ' + str(turks_per_batch)
+    print 'Average test error : ' + str(error_means.mean()*image_per_batch) +'%'
+    print 'Error variance     : ' + str(error_variances.mean()*image_per_batch) +'%' 
 
 
-def get_error(entry):
+def find_dataset(entry):
+    file = parse_filename(entry[img_url])
+    return file.split('_')[0]
+    
+def get_error(entry,PATH):
     file = parse_filename(entry[img_url])
     f = open(PATH+file,'r')
     labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.')