Mercurial > ift6266
diff deep/amt/amt.py @ 449:7bdd412754ea
Added support to calculate the human consensus error
author | humel |
---|---|
date | Sun, 09 May 2010 15:01:31 -0400 |
parents | 5777b5041ac9 |
children |
line wrap: on
line diff
--- a/deep/amt/amt.py Fri May 07 17:24:21 2010 -0400 +++ b/deep/amt/amt.py Sun May 09 15:01:31 2010 -0400 @@ -33,11 +33,28 @@ """ import csv,numpy,re,decimal +from ift6266 import datasets +from pylearn.io import filetensor as ft + +fnist = open('nist_train_class_freq.ft','r') +fp07 = open('p07_train_class_freq.ft','r') +fpnist = open('pnist_train_class_freq.ft','r') + +nist_freq_table = ft.read(fnist) +p07_freq_table = ft.read(fp07) +pnist_freq_table = ft.read(fpnist) + +fnist.close();fp07.close();fpnist.close() DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/', 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/', 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' } +freq_tables = { 'nist' : nist_freq_table, + 'p07' : p07_freq_table, + 'pnist': pnist_freq_table } + + CVSFILE = None #PATH = None answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ] @@ -123,7 +140,7 @@ raise ('Inapropriate option for the type of classification :' + type) -def test_error(assoc_type=TYPE): +def test_error(assoc_type=TYPE,consensus=True): answer_assoc = classes_answer(assoc_type) turks = [] @@ -136,16 +153,22 @@ total_uniq_entries = len(entries) / turks_per_batch - errors = numpy.zeros((len(entries),)) - num_examples = numpy.zeros((len(entries),)) - error_means = numpy.zeros((total_uniq_entries,)) + error_variances = numpy.zeros((total_uniq_entries,)) - - for i in range(total_uniq_entries): - for t in range(turks_per_batch): - errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type) - #error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() - error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() + + if consensus: + errors = numpy.zeros((total_uniq_entries,)) + num_examples = numpy.zeros((total_uniq_entries,)) + for i in range(total_uniq_entries): + errors[i],num_examples[i] = get_turk_consensus_error(entries[i*turks_per_batch:(i+1)*turks_per_batch],assoc_type) + error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() + else: + errors = numpy.zeros((len(entries),)) + num_examples = numpy.zeros((len(entries),)) + for i in range(total_uniq_entries): + for t in range(turks_per_batch): + errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type) + error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() percentage_error = 100. * errors.sum() / num_examples.sum() print 'Testing on : ' + str(assoc_type) @@ -176,6 +199,43 @@ test_error+=1 return test_error,image_per_batch-cnt +def get_turk_consensus_error(entries, type): + answer_assoc = classes_answer(type) + labels = get_labels(entries[0],type) + test_error = 0 + cnt = 0 + answer= [] + freq_t = freq_tables[find_dataset(entries[0])] + for i in range(len(answer_labels)): + if labels[i] == -1: + cnt+=1 + else: + answers = [ entry[answer_labels[i]] for entry in entries ] + if answers[0] != answers[1] and answers[1] != answers[2] and answers[0] != answers[2]: + m = max([ freq_t[answer_assoc[answer]] for answer in answers]) + for answer in answers: + if freq_t[answer_assoc[answer]] == m : + a = answer + else: + for answer in answers: + if answers.count(answer) > 1 : + a =answer + try: + if answer_assoc[answer] != labels[i]: + test_error+=1 + except: + test_error+=1 + return test_error,image_per_batch-cnt +def frequency_table(): + filenames = ['nist_train_class_freq.ft','p07_train_class_freq.ft','pnist_train_class_freq.ft'] + iterators = [datasets.nist_all(),datasets.nist_P07(),datasets.PNIST07()] + for dataset,filename in zip(iterators,filenames): + freq_table = numpy.zeros(62) + for x,y in dataset.train(1): + freq_table[int(y)]+=1 + f = open(filename,'w') + ft.write(f,freq_table) + f.close() def get_labels(entry,type): file = parse_filename(entry[img_url]) @@ -204,4 +264,6 @@ if __name__ =='__main__': import sys CVSFILE = sys.argv[1] - test_error(sys.argv[2]) + test_error(sys.argv[2],int(sys.argv[3])) + #frequency_table() +