diff deep/amt/amt.py @ 449:7bdd412754ea

Added support for calculating the human consensus error
author humel
date Sun, 09 May 2010 15:01:31 -0400
parents 5777b5041ac9
children
line wrap: on
line diff
--- a/deep/amt/amt.py	Fri May 07 17:24:21 2010 -0400
+++ b/deep/amt/amt.py	Sun May 09 15:01:31 2010 -0400
@@ -33,11 +33,28 @@
 """
 
 import csv,numpy,re,decimal
+from ift6266 import datasets
+from pylearn.io import filetensor as ft
+
+fnist = open('nist_train_class_freq.ft','rb')   # per-class training frequencies, written by frequency_table() below
+fp07 = open('p07_train_class_freq.ft','rb')
+fpnist = open('pnist_train_class_freq.ft','rb')
+
+nist_freq_table  = ft.read(fnist)
+p07_freq_table   = ft.read(fp07)
+pnist_freq_table = ft.read(fpnist)
+
+fnist.close(); fp07.close(); fpnist.close()
 
 DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/',
                  'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/',
                  'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }
 
+freq_tables = { 'nist' : nist_freq_table,
+                'p07'  : p07_freq_table,
+                'pnist': pnist_freq_table }
+
+
 CVSFILE = None
 #PATH = None
 answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ]
@@ -123,7 +140,7 @@
         raise ('Inapropriate option for the type of classification :' + type)
 
 
-def test_error(assoc_type=TYPE):
+def test_error(assoc_type=TYPE,consensus=True):
     answer_assoc = classes_answer(assoc_type)
 
     turks = []
@@ -136,16 +153,22 @@
 
     total_uniq_entries = len(entries) / turks_per_batch
 
-    errors = numpy.zeros((len(entries),))
-    num_examples = numpy.zeros((len(entries),))
-    error_means = numpy.zeros((total_uniq_entries,))
+
     error_variances = numpy.zeros((total_uniq_entries,))
-
-    for i in range(total_uniq_entries):
-        for t in range(turks_per_batch):
-            errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
-        #error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
-        error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
+
+    if consensus:
+        errors = numpy.zeros((total_uniq_entries,))
+        num_examples = numpy.zeros((total_uniq_entries,))
+        for i in range(total_uniq_entries):
+            errors[i],num_examples[i] = get_turk_consensus_error(entries[i*turks_per_batch:(i+1)*turks_per_batch],assoc_type)
+            error_variances[i] = 0.  # a single consensus error per batch of turks, so there is no per-turk variance
+    else:
+        errors = numpy.zeros((len(entries),))
+        num_examples = numpy.zeros((len(entries),))
+        for i in range(total_uniq_entries):
+            for t in range(turks_per_batch):
+                errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
+            error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
         
     percentage_error = 100. * errors.sum() / num_examples.sum()
     print 'Testing on         : ' + str(assoc_type)
@@ -176,6 +199,43 @@
                 test_error+=1
     return test_error,image_per_batch-cnt
 
+def get_turk_consensus_error(entries, type):
+    answer_assoc = classes_answer(type)
+    labels = get_labels(entries[0],type)
+    test_error = 0
+    cnt = 0
+    a = None  # consensus answer for the current image position
+    freq_t = freq_tables[find_dataset(entries[0])]  # training-set class frequencies for this dataset
+    for i in range(len(answer_labels)):
+        if labels[i] == -1:
+            cnt+=1
+        else:
+            answers = [ entry[answer_labels[i]] for entry in entries ]
+            if len(set(answers)) == len(answers):  # all turks disagree
+                m = max([ freq_t[answer_assoc[answer]] for answer in answers ])  # tie-break by training-set class frequency
+                for answer in answers:
+                    if freq_t[answer_assoc[answer]] == m:
+                        a = answer
+            else:  # at least two turks agree: take the majority answer
+                for answer in answers:
+                    if answers.count(answer) > 1:
+                        a = answer
+            try:
+                if answer_assoc[a] != labels[i]:
+                    test_error+=1
+            except KeyError:  # unrecognized consensus answer counts as an error
+                test_error+=1
+    return test_error,image_per_batch-cnt
+def frequency_table():
+    filenames = ['nist_train_class_freq.ft','p07_train_class_freq.ft','pnist_train_class_freq.ft']
+    iterators = [datasets.nist_all(),datasets.nist_P07(),datasets.PNIST07()]
+    for dataset,filename in zip(iterators,filenames):
+        freq_table = numpy.zeros(62)  # 62 NIST classes: digits, uppercase and lowercase letters
+        for x,y in dataset.train(1):
+            freq_table[int(y)]+=1
+        f = open(filename,'wb')
+        ft.write(f,freq_table)
+        f.close()
 
 def get_labels(entry,type):
     file = parse_filename(entry[img_url])
@@ -204,4 +264,6 @@
 if __name__ =='__main__':
     import sys
     CVSFILE = sys.argv[1]
-    test_error(sys.argv[2])
+    test_error(sys.argv[2],int(sys.argv[3]))  # argv[3]: 1 to score the turk consensus, 0 for per-turk errors
+    #frequency_table()
+
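For readers of this changeset, the consensus rule implemented in get_turk_consensus_error() can be summarized as follows: for each image position, take the majority answer among the turks in the batch; when every turk disagrees, fall back to the class that occurs most often in the training set of the dataset the image came from (the .ft frequency tables loaded at the top of the module). Below is a minimal standalone sketch of that rule, not the project's code; consensus_answer, answer_assoc and freq_table are hypothetical stand-ins for the per-image logic, the classes_answer() mapping and the loaded frequency tables.

    # Minimal standalone sketch of the consensus rule (an assumed reading of
    # this changeset, not the project's code).  'answer_assoc' and 'freq_table'
    # are hypothetical stand-ins for classes_answer() and the .ft frequency tables.
    from collections import Counter

    def consensus_answer(answers, answer_assoc, freq_table):
        """Majority vote over the turk answers for one image; if every turk
        disagrees, fall back to the class most frequent in the training set."""
        top, top_count = Counter(answers).most_common(1)[0]
        if top_count > 1:
            return top                                   # at least two turks agree
        return max(answers, key=lambda ans: freq_table[answer_assoc[ans]])

    # toy example with a made-up 3-class mapping and training frequencies
    answer_assoc = {'a': 0, 'b': 1, 'c': 2}
    freq_table = [10., 50., 30.]
    print(consensus_answer(['a', 'b', 'a'], answer_assoc, freq_table))   # 'a' (majority)
    print(consensus_answer(['a', 'b', 'c'], answer_assoc, freq_table))   # 'b' (most frequent class)

With the updated __main__ block, the script is presumably invoked as something like: python amt.py <turk results csv> <classification type> <consensus flag>, where the last argument is 1 to score the turk consensus and 0 to score each turk individually.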