comparison deep/amt/amt.py @ 449:7bdd412754ea

Added support to calculate the human consensus error
author humel
date Sun, 09 May 2010 15:01:31 -0400
parents 5777b5041ac9
children
comparing 448:b2a7d93caa0f with 449:7bdd412754ea
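
The change below adds a "human consensus" error to the AMT evaluation script: the three turker answers collected for each image are reduced to a single answer by majority vote, and three-way disagreements are broken in favour of the answer whose class is most frequent in the training set of the dataset being tested. As a reading aid, here is a minimal standalone sketch of that voting rule; consensus_answer, freq_table and class_of are illustrative names, not part of the changeset:

    def consensus_answer(answers, freq_table, class_of):
        """Reduce the three turker answers for one image to a single answer.

        answers    -- the three raw answer strings
        freq_table -- per-class counts computed from the training set
        class_of   -- maps a raw answer string to its class index
        """
        # If any answer appears at least twice, the majority wins.
        for answer in answers:
            if answers.count(answer) > 1:
                return answer
        # All three answers differ: keep the one whose class is the most
        # frequent in the training data.
        return max(answers, key=lambda answer: freq_table[class_of(answer)])

In the diff, get_turk_consensus_error applies this rule to each of the ten answer columns of a batch and counts the consensus answers that disagree with the reference labels.
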
...
 Average test error : 50.9803921569%
 Error variance : 1.33333333333%
 """
 
 import csv,numpy,re,decimal
+from ift6266 import datasets
+from pylearn.io import filetensor as ft
+
+fnist = open('nist_train_class_freq.ft','r')
+fp07 = open('p07_train_class_freq.ft','r')
+fpnist = open('pnist_train_class_freq.ft','r')
+
+nist_freq_table = ft.read(fnist)
+p07_freq_table = ft.read(fp07)
+pnist_freq_table = ft.read(fpnist)
+
+fnist.close();fp07.close();fpnist.close()
 
 DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/',
                  'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/',
                  'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }
 
+freq_tables = { 'nist' : nist_freq_table,
+                'p07' : p07_freq_table,
+                'pnist': pnist_freq_table }
+
 
 CVSFILE = None
 #PATH = None
 answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ]
 img_url = 'Input.image_url'
...
         return digit_classes_assoc()
     else:
         raise Exception('Inappropriate option for the type of classification: ' + type)
 
 
-def test_error(assoc_type=TYPE):
+def test_error(assoc_type=TYPE,consensus=True):
     answer_assoc = classes_answer(assoc_type)
 
     turks = []
     reader = csv.DictReader(open(CVSFILE), delimiter=',')
     entries = [ turk for turk in reader ]
...
     if len(entries) % turks_per_batch != 0 :
         raise Exception('Wrong number of entries or turks_per_batch')
 
     total_uniq_entries = len(entries) / turks_per_batch
 
-    errors = numpy.zeros((len(entries),))
-    num_examples = numpy.zeros((len(entries),))
-    error_means = numpy.zeros((total_uniq_entries,))
     error_variances = numpy.zeros((total_uniq_entries,))
 
-    for i in range(total_uniq_entries):
-        for t in range(turks_per_batch):
-            errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
-        #error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
-        error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
+    if consensus:
+        errors = numpy.zeros((total_uniq_entries,))
+        num_examples = numpy.zeros((total_uniq_entries,))
+        for i in range(total_uniq_entries):
+            errors[i],num_examples[i] = get_turk_consensus_error(entries[i*turks_per_batch:(i+1)*turks_per_batch],assoc_type)
+            # A single consensus answer per batch leaves no per-turk spread,
+            # so error_variances stays at zero in this mode.
+    else:
+        errors = numpy.zeros((len(entries),))
+        num_examples = numpy.zeros((len(entries),))
+        for i in range(total_uniq_entries):
+            for t in range(turks_per_batch):
+                errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
+            error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
 
     percentage_error = 100. * errors.sum() / num_examples.sum()
     print 'Testing on : ' + str(assoc_type)
     print 'Total entries : ' + str(num_examples.sum())
     print 'Turks per batch : ' + str(turks_per_batch)
...
                     test_error+=1
             except:
                 test_error+=1
     return test_error,image_per_batch-cnt
 
+def get_turk_consensus_error(entries, type):
+    answer_assoc = classes_answer(type)
+    labels = get_labels(entries[0],type)
+    test_error = 0
+    cnt = 0
+    answer = []
+    freq_t = freq_tables[find_dataset(entries[0])]
+    for i in range(len(answer_labels)):
+        if labels[i] == -1:
+            cnt+=1
+        else:
+            answers = [ entry[answer_labels[i]] for entry in entries ]
+            if answers[0] != answers[1] and answers[1] != answers[2] and answers[0] != answers[2]:
+                # Three-way disagreement: keep the answer whose class is the
+                # most frequent one in the training set of this dataset.
+                m = max([ freq_t[answer_assoc[answer]] for answer in answers ])
+                for answer in answers:
+                    if freq_t[answer_assoc[answer]] == m:
+                        a = answer
+            else:
+                # Majority vote: keep the answer given by at least two turks.
+                for answer in answers:
+                    if answers.count(answer) > 1:
+                        a = answer
+            try:
+                if answer_assoc[a] != labels[i]:
+                    test_error+=1
+            except:
+                test_error+=1
+    return test_error,image_per_batch-cnt
+
+def frequency_table():
+    filenames = ['nist_train_class_freq.ft','p07_train_class_freq.ft','pnist_train_class_freq.ft']
+    iterators = [datasets.nist_all(),datasets.nist_P07(),datasets.PNIST07()]
+    for dataset,filename in zip(iterators,filenames):
+        freq_table = numpy.zeros(62)
+        for x,y in dataset.train(1):
+            freq_table[int(y)]+=1
+        f = open(filename,'w')
+        ft.write(f,freq_table)
+        f.close()
 
 def get_labels(entry,type):
     file = parse_filename(entry[img_url])
     path = DATASET_PATH[find_dataset(entry)]
     f = open(path+file,'r')
...
     return filename.split('.')[0]+'.txt'
 
 if __name__ == '__main__':
     import sys
     CVSFILE = sys.argv[1]
-    test_error(sys.argv[2])
+    test_error(sys.argv[2],int(sys.argv[3]))
+    #frequency_table()
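
For reference, a usage sketch after this change; the CSV file name and the 'digits' classification type are placeholders, and importing the module assumes the *_train_class_freq.ft files written by frequency_table() are present in the working directory:

    # Command line, as wired up in __main__ (argument values are placeholders):
    #   python amt.py batch_results.csv digits 1    -> consensus error
    #   python amt.py batch_results.csv digits 0    -> per-turk error, as before
    #
    # The same entry point, driven programmatically:
    import amt

    amt.CVSFILE = 'batch_results.csv'           # results file exported from AMT
    amt.test_error('digits', consensus=True)    # majority-vote consensus error
    amt.test_error('digits', consensus=False)   # individual-turk error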