changeset 412:6478eef4f8aa

Added support for calculating the test error over different set of classes (lower,upper,digits,all,36)
author humel
date Thu, 29 Apr 2010 14:28:52 -0400
parents 4f69d915d142
children f2dd75248483
files deep/amt/amt.py
diffstat 1 files changed, 141 insertions(+), 23 deletions(-) [+]
line wrap: on
line diff
--- a/deep/amt/amt.py	Thu Apr 29 13:18:15 2010 -0400
+++ b/deep/amt/amt.py	Thu Apr 29 14:28:52 2010 -0400
@@ -1,4 +1,36 @@
-# Script usage : python amt.py filname.cvs
+# Script usage : python amt.py filname.cvs type
+"""
+[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv all
+Testing on         : all
+Total entries      : 300.0
+Turks per batch    : 3
+Average test error : 45.3333333333%
+Error variance     : 7.77777777778%
+[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv 36
+Testing on         : 36
+Total entries      : 300.0
+Turks per batch    : 3
+Average test error : 51.6666666667%
+Error variance     : 3.33333333333%
+[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv upper
+Testing on         : upper
+Total entries      : 63.0
+Turks per batch    : 3
+Average test error : 53.9682539683%
+Error variance     : 1.77777777778%
+[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv lower
+Testing on         : lower
+Total entries      : 135.0
+Turks per batch    : 3
+Average test error : 37.037037037%
+Error variance     : 3.77777777778%
+[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv digits
+Testing on         : digits
+Total entries      : 102.0
+Turks per batch    : 3
+Average test error : 50.9803921569%
+Error variance     : 1.33333333333%
+"""
 
 import csv,numpy,re,decimal
 
@@ -12,8 +44,9 @@
 img_url = 'Input.image_url'
 turks_per_batch = 3
 image_per_batch = 10
+TYPE = None
 
-def setup_association():
+def all_classes_assoc():
     answer_assoc = {}
     for i in range(0,10):
         answer_assoc[str(i)]=i
@@ -23,9 +56,74 @@
         answer_assoc[chr(i+61)]=i
     return answer_assoc
 
-answer_assoc = setup_association()
+def upper_classes_assoc():
+    answer_assoc = {}
+    for i in range(10,36):
+        answer_assoc[chr(i+55)]=i
+    return answer_assoc
+
+def lower_classes_assoc():
+    answer_assoc = {}
+    for i in range(36,62):
+        answer_assoc[chr(i+61)]=i
+    return answer_assoc
+
+def digit_classes_assoc():
+    answer_assoc = {}
+    for i in range(0,10):
+        answer_assoc[str(i)]=i
+    return answer_assoc
+
+def tsix_classes_assoc():
+    answer_assoc = {}
+    for i in range(10,36):
+        answer_assoc[chr(i+55)]=i
+        answer_assoc[chr(i+87)]=i
+    return answer_assoc
+
+def upper_label_assoc(ulabel):
+    for i in range(len(ulabel)):
+        if ulabel[i] < 10 or ulabel[i] > 35 :
+            ulabel[i] = -1
+    return ulabel
 
-def test_error():
+def lower_label_assoc(ulabel):
+    for i in range(len(ulabel)):
+        if ulabel[i] < 36 or ulabel[i] > 61 :
+            ulabel[i] = -1
+    return ulabel
+
+def tsix_label_assoc(ulabel): 
+    for i in range(len(ulabel)):
+        if ulabel[i] > 35 and ulabel[i] < 62 :
+            ulabel[i] = ulabel[i] - 26
+    return ulabel
+
+def digit_label_assoc(ulabel):
+    for i in range(len(ulabel)):
+        if ulabel[i] < 0 or ulabel[i] > 9 :
+            ulabel[i] = -1
+
+    return ulabel
+
+def classes_answer(type):
+    if type == 'all':
+        return all_classes_assoc()
+    elif type == '36':
+        return tsix_classes_assoc()
+    elif type == 'lower':
+        return lower_classes_assoc()
+    elif type == 'upper':
+        return upper_classes_assoc()
+    elif type == 'digits':
+        return digit_classes_assoc()
+    else:
+        raise ('Inapropriate option for the type of classification :' + type)
+
+
+def test_error(assoc_type=TYPE):
+    answer_assoc = classes_answer(assoc_type)
+
     turks = []
     reader = csv.DictReader(open(CVSFILE), delimiter=',')
     entries = [ turk for turk in reader ]
@@ -37,20 +135,21 @@
     total_uniq_entries = len(entries) / turks_per_batch
 
     errors = numpy.zeros((len(entries),))
+    num_examples = numpy.zeros((len(entries),))
     error_means = numpy.zeros((total_uniq_entries,))
     error_variances = numpy.zeros((total_uniq_entries,))
 
-    PATH = DATASET_PATH[find_dataset(entries[0])]
     for i in range(total_uniq_entries):
         for t in range(turks_per_batch):
-            errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],PATH)
-        error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
+            errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
+        #error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
         error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
-
-    decimal.getcontext().prec = 3
-    print 'Total entries      : ' + str(len(entries))
+        
+    percentage_error = 100. * errors.sum() / num_examples.sum()
+    print 'Testing on         : ' + str(assoc_type)
+    print 'Total entries      : ' + str(num_examples.sum())
     print 'Turks per batch    : ' + str(turks_per_batch)
-    print 'Average test error : ' + str(error_means.mean()*image_per_batch) +'%'
+    print 'Average test error : ' + str(percentage_error) +'%'
     print 'Error variance     : ' + str(error_variances.mean()*image_per_batch) +'%' 
 
 
@@ -58,24 +157,43 @@
     file = parse_filename(entry[img_url])
     return file.split('_')[0]
     
-def get_error(entry,PATH):
-    file = parse_filename(entry[img_url])
-    f = open(PATH+file,'r')
-    labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.')
-    f.close()
+def get_error(entry, type):
+    answer_assoc = classes_answer(type)
+    labels = get_labels(entry,type)
     test_error = 0
+    cnt = 0
     for i in range(len(answer_labels)):
-        answer = entry[answer_labels[i]]
-        if len(answer) != 0:
+        if labels[i] == -1:
+            cnt+=1
+        else:
+            answer = entry[answer_labels[i]]
             try:
-                if answer_assoc[answer] != int(labels[i]):
+                if answer_assoc[answer] != labels[i]:
                     test_error+=1
             except:
                 test_error+=1
-        else:
-            test_error+=1
-    return test_error
+    return test_error,image_per_batch-cnt
+
 
+def get_labels(entry,type):
+    file = parse_filename(entry[img_url])
+    path = DATASET_PATH[find_dataset(entry)]
+    f = open(path+file,'r')
+    str_labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.')
+    unrestricted_labels = [ int(element) for element in str_labels ]
+    if type == 'all':
+        return unrestricted_labels
+    elif type == '36':
+        return tsix_label_assoc(unrestricted_labels)
+    elif type == 'lower':
+        return lower_label_assoc(unrestricted_labels)
+    elif type == 'upper':
+        return upper_label_assoc(unrestricted_labels)
+    elif type == 'digits':
+        return digit_label_assoc(unrestricted_labels)
+    else:
+        raise ('Inapropriate option for the type of classification :' + str(type))
+    
 
 def parse_filename(string):
     filename = string.split('/')[-1]
@@ -84,4 +202,4 @@
 if __name__ =='__main__':
     import sys
     CVSFILE = sys.argv[1]
-    test_error()
+    test_error(sys.argv[2])