diff deep/amt/amt.py @ 399:99905d9bc9dd

Initial commit for calculating the test error of the AMT classifier
author humel
date Wed, 28 Apr 2010 00:38:31 -0400
parents
children 86d5e583e278
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/deep/amt/amt.py	Wed Apr 28 00:38:31 2010 -0400
@@ -0,0 +1,66 @@
+# Script usage : python amt.py filname.cvs
+
+import csv,numpy,re
+
+PATH = 'nist/'
+CVSFILE = None
+
+answer_labels = [ 'Answer.C'+str(i+1) for i in range(10) ]
+img_url = 'Input.image_url'
+turks_per_batch = 3
+
+
+def test_error():
+    turks = []
+    reader = csv.DictReader(open(CVSFILE), delimiter=',')
+    entries = [ turk for turk in reader ]
+
+    errors = numpy.zeros((len(entries),))
+    if len(entries) % turks_per_batch != 0 :
+        raise Exception('Wrong number of entries or turks_per_batch')
+
+    total_uniq_entries = len(entries) / turks_per_batch
+
+    errors = numpy.zeros((len(entries),))
+    error_means = numpy.zeros((total_uniq_entries,))
+    error_variances = numpy.zeros((total_uniq_entries,))
+
+
+    for i in range(total_uniq_entries):
+        for t in range(turks_per_batch):
+            errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t])
+        error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
+        error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
+
+    print 'Average test error: ' + str(error_means.mean()*10) +'%'
+    print 'Error variance    : ' + str(error_variances.mean()*10) +'%' 
+
+
+def get_error(entry):
+    file = parse_filename(entry[img_url])
+    f = open(PATH+file,'r')
+    labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.')
+    f.close()
+    test_error = 0
+    for i in range(len(answer_labels)):
+        try: 
+            answer = int(entry[answer_labels[i]])
+        except:
+            try :
+                answer = ord(entry[answer_labels[i]])
+            except:
+                test_error+=1
+                continue
+        if answer != int(labels[i]):
+            test_error+=1
+    return test_error
+
+
+def parse_filename(string):
+    filename = string.split('/')[-1]
+    return filename.split('.')[0]+'.txt'
+
+if __name__ =='__main__':
+    import sys
+    CVSFILE = sys.argv[1]
+    test_error()