Mercurial > ift6266
comparison deep/amt/amt.py @ 449:7bdd412754ea
Added support to calculate the human consensus error
author | humel |
---|---|
date | Sun, 09 May 2010 15:01:31 -0400 |
parents | 5777b5041ac9 |
children |
comparison
equal
deleted
inserted
replaced
448:b2a7d93caa0f | 449:7bdd412754ea |
---|---|
31 Average test error : 50.9803921569% | 31 Average test error : 50.9803921569% |
32 Error variance : 1.33333333333% | 32 Error variance : 1.33333333333% |
33 """ | 33 """ |
34 | 34 |
35 import csv,numpy,re,decimal | 35 import csv,numpy,re,decimal |
36 from ift6266 import datasets | |
37 from pylearn.io import filetensor as ft | |
38 | |
def _read_freq_table(path):
    """Read one pre-computed class-frequency table from a filetensor file.

    Ensures the file handle is closed even when ft.read() raises (the
    previous open/read/close sequence leaked the handle on error).
    """
    f = open(path, 'r')
    try:
        return ft.read(f)
    finally:
        f.close()

# Per-class training-set frequency tables, one per dataset.  They are
# used to break three-way ties between turk answers (see
# get_turk_consensus_error).  Regenerate them with frequency_table().
nist_freq_table = _read_freq_table('nist_train_class_freq.ft')
p07_freq_table = _read_freq_table('p07_train_class_freq.ft')
pnist_freq_table = _read_freq_table('pnist_train_class_freq.ft')
36 | 48 |
# Where the label files for each AMT dataset live on the LISA cluster.
DATASET_PATH = dict(
    nist='/data/lisa/data/ift6266h10/amt_data/nist/',
    p07='/data/lisa/data/ift6266h10/amt_data/p07/',
    pnist='/data/lisa/data/ift6266h10/amt_data/pnist/',
)
52 | |
# Dataset name -> class-frequency table, for consensus tie-breaking.
freq_tables = dict(
    nist=nist_freq_table,
    p07=p07_freq_table,
    pnist=pnist_freq_table,
)
56 | |
40 | 57 |
# Path to the AMT results CSV; filled in from argv under __main__.
CVSFILE = None
# CSV columns holding the ten per-character turk answers.
answer_labels = ['Answer.c' + str(n) for n in range(1, 11)]
# CSV column holding the URL of the image batch shown to the turk.
img_url = 'Input.image_url'
121 return digit_classes_assoc() | 138 return digit_classes_assoc() |
122 else: | 139 else: |
123 raise ('Inapropriate option for the type of classification :' + type) | 140 raise ('Inapropriate option for the type of classification :' + type) |
124 | 141 |
125 | 142 |
126 def test_error(assoc_type=TYPE): | 143 def test_error(assoc_type=TYPE,consensus=True): |
127 answer_assoc = classes_answer(assoc_type) | 144 answer_assoc = classes_answer(assoc_type) |
128 | 145 |
129 turks = [] | 146 turks = [] |
130 reader = csv.DictReader(open(CVSFILE), delimiter=',') | 147 reader = csv.DictReader(open(CVSFILE), delimiter=',') |
131 entries = [ turk for turk in reader ] | 148 entries = [ turk for turk in reader ] |
134 if len(entries) % turks_per_batch != 0 : | 151 if len(entries) % turks_per_batch != 0 : |
135 raise Exception('Wrong number of entries or turks_per_batch') | 152 raise Exception('Wrong number of entries or turks_per_batch') |
136 | 153 |
137 total_uniq_entries = len(entries) / turks_per_batch | 154 total_uniq_entries = len(entries) / turks_per_batch |
138 | 155 |
139 errors = numpy.zeros((len(entries),)) | 156 |
140 num_examples = numpy.zeros((len(entries),)) | |
141 error_means = numpy.zeros((total_uniq_entries,)) | |
142 error_variances = numpy.zeros((total_uniq_entries,)) | 157 error_variances = numpy.zeros((total_uniq_entries,)) |
143 | 158 |
144 for i in range(total_uniq_entries): | 159 if consensus: |
145 for t in range(turks_per_batch): | 160 errors = numpy.zeros((total_uniq_entries,)) |
146 errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type) | 161 num_examples = numpy.zeros((total_uniq_entries,)) |
147 #error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() | 162 for i in range(total_uniq_entries): |
148 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() | 163 errors[i],num_examples[i] = get_turk_consensus_error(entries[i*turks_per_batch:(i+1)*turks_per_batch],assoc_type) |
164 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() | |
165 else: | |
166 errors = numpy.zeros((len(entries),)) | |
167 num_examples = numpy.zeros((len(entries),)) | |
168 for i in range(total_uniq_entries): | |
169 for t in range(turks_per_batch): | |
170 errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type) | |
171 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() | |
149 | 172 |
150 percentage_error = 100. * errors.sum() / num_examples.sum() | 173 percentage_error = 100. * errors.sum() / num_examples.sum() |
151 print 'Testing on : ' + str(assoc_type) | 174 print 'Testing on : ' + str(assoc_type) |
152 print 'Total entries : ' + str(num_examples.sum()) | 175 print 'Total entries : ' + str(num_examples.sum()) |
153 print 'Turks per batch : ' + str(turks_per_batch) | 176 print 'Turks per batch : ' + str(turks_per_batch) |
174 test_error+=1 | 197 test_error+=1 |
175 except: | 198 except: |
176 test_error+=1 | 199 test_error+=1 |
177 return test_error,image_per_batch-cnt | 200 return test_error,image_per_batch-cnt |
178 | 201 |
def get_turk_consensus_error(entries, type):
    """Compute the consensus error for one batch of turk answers.

    entries -- the CSV rows (one per turk) for a single image batch;
               the tie-breaking logic assumes exactly 3 turks per batch.
    type    -- classification granularity, forwarded to classes_answer().

    Returns (errors, labeled_examples): the number of consensus answers
    that disagree with the true label, and image_per_batch minus the
    number of unlabeled (-1) slots.
    """
    answer_assoc = classes_answer(type)
    labels = get_labels(entries[0], type)
    errors = 0
    unlabeled = 0  # slots whose true label is -1 (no image shown)
    freq_t = freq_tables[find_dataset(entries[0])]
    for i in range(len(answer_labels)):
        if labels[i] == -1:
            unlabeled += 1
            continue
        answers = [entry[answer_labels[i]] for entry in entries]
        if answers[0] != answers[1] and answers[1] != answers[2] and answers[0] != answers[2]:
            # Three-way disagreement: keep the answer whose class is the
            # most frequent in the matching training set.
            m = max([freq_t[answer_assoc[answer]] for answer in answers])
            for answer in answers:
                if freq_t[answer_assoc[answer]] == m:
                    consensus = answer
        else:
            # At least two turks agree: the majority answer wins.
            for answer in answers:
                if answers.count(answer) > 1:
                    consensus = answer
        try:
            # BUGFIX: the original compared `answer` -- the stale loop
            # variable left over from the searches above (i.e. the last
            # turk's answer) -- instead of the consensus answer `a` it
            # had just computed.  Use the consensus value.
            if answer_assoc[consensus] != labels[i]:
                errors += 1
        except KeyError:
            # An answer string with no class mapping counts as an error
            # (narrowed from a bare except that hid unrelated failures).
            errors += 1
    # image_per_batch is a module-level global -- presumably the number
    # of images shown per HIT; TODO confirm where it is defined.
    return errors, image_per_batch - unlabeled
def frequency_table():
    """Regenerate the per-dataset class-frequency tables.

    For each of the three datasets (NIST, P07, PNIST07), count how many
    training examples fall in each of the 62 character classes and dump
    the table to the matching *_class_freq.ft file -- the files loaded
    at module import time.  Run once, offline.
    """
    filenames = ['nist_train_class_freq.ft',
                 'p07_train_class_freq.ft',
                 'pnist_train_class_freq.ft']
    iterators = [datasets.nist_all(), datasets.nist_P07(), datasets.PNIST07()]
    for dataset, filename in zip(iterators, filenames):
        freq_table = numpy.zeros(62)
        for x, y in dataset.train(1):
            freq_table[int(y)] += 1
        f = open(filename, 'w')
        try:
            # Close the file even if ft.write raises (original leaked it).
            ft.write(f, freq_table)
        finally:
            f.close()
179 | 239 |
180 def get_labels(entry,type): | 240 def get_labels(entry,type): |
181 file = parse_filename(entry[img_url]) | 241 file = parse_filename(entry[img_url]) |
182 path = DATASET_PATH[find_dataset(entry)] | 242 path = DATASET_PATH[find_dataset(entry)] |
183 f = open(path+file,'r') | 243 f = open(path+file,'r') |
202 return filename.split('.')[0]+'.txt' | 262 return filename.split('.')[0]+'.txt' |
203 | 263 |
204 if __name__ =='__main__': | 264 if __name__ =='__main__': |
205 import sys | 265 import sys |
206 CVSFILE = sys.argv[1] | 266 CVSFILE = sys.argv[1] |
207 test_error(sys.argv[2]) | 267 test_error(sys.argv[2],int(sys.argv[3])) |
268 #frequency_table() | |
269 |