Mercurial > ift6266
comparison deep/amt/amt.py @ 412:6478eef4f8aa
Added support for calculating the test error over different sets of classes (lower, upper, digits, all, 36)
author | humel |
---|---|
date | Thu, 29 Apr 2010 14:28:52 -0400 |
parents | a11692910312 |
children | 5777b5041ac9 |
comparison
equal
deleted
inserted
replaced
411:4f69d915d142 | 412:6478eef4f8aa |
---|---|
1 # Script usage : python amt.py filename.csv | 1 # Script usage : python amt.py filename.csv type |
2 """ | |
3 [rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv all | |
4 Testing on : all | |
5 Total entries : 300.0 | |
6 Turks per batch : 3 | |
7 Average test error : 45.3333333333% | |
8 Error variance : 7.77777777778% | |
9 [rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv 36 | |
10 Testing on : 36 | |
11 Total entries : 300.0 | |
12 Turks per batch : 3 | |
13 Average test error : 51.6666666667% | |
14 Error variance : 3.33333333333% | |
15 [rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv upper | |
16 Testing on : upper | |
17 Total entries : 63.0 | |
18 Turks per batch : 3 | |
19 Average test error : 53.9682539683% | |
20 Error variance : 1.77777777778% | |
21 [rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv lower | |
22 Testing on : lower | |
23 Total entries : 135.0 | |
24 Turks per batch : 3 | |
25 Average test error : 37.037037037% | |
26 Error variance : 3.77777777778% | |
27 [rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv digits | |
28 Testing on : digits | |
29 Total entries : 102.0 | |
30 Turks per batch : 3 | |
31 Average test error : 50.9803921569% | |
32 Error variance : 1.33333333333% | |
33 """ | |
2 | 34 |
3 import csv,numpy,re,decimal | 35 import csv,numpy,re,decimal |
4 | 36 |
5 DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/', | 37 DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/', |
6 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/', | 38 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/', |
10 #PATH = None | 42 #PATH = None |
11 answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ] | 43 answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ] |
12 img_url = 'Input.image_url' | 44 img_url = 'Input.image_url' |
13 turks_per_batch = 3 | 45 turks_per_batch = 3 |
14 image_per_batch = 10 | 46 image_per_batch = 10 |
15 | 47 TYPE = None |
16 def setup_association(): | 48 |
49 def all_classes_assoc(): | |
17 answer_assoc = {} | 50 answer_assoc = {} |
18 for i in range(0,10): | 51 for i in range(0,10): |
19 answer_assoc[str(i)]=i | 52 answer_assoc[str(i)]=i |
20 for i in range(10,36): | 53 for i in range(10,36): |
21 answer_assoc[chr(i+55)]=i | 54 answer_assoc[chr(i+55)]=i |
22 for i in range(36,62): | 55 for i in range(36,62): |
23 answer_assoc[chr(i+61)]=i | 56 answer_assoc[chr(i+61)]=i |
24 return answer_assoc | 57 return answer_assoc |
25 | 58 |
26 answer_assoc = setup_association() | 59 def upper_classes_assoc(): |
27 | 60 answer_assoc = {} |
28 def test_error(): | 61 for i in range(10,36): |
62 answer_assoc[chr(i+55)]=i | |
63 return answer_assoc | |
64 | |
65 def lower_classes_assoc(): | |
66 answer_assoc = {} | |
67 for i in range(36,62): | |
68 answer_assoc[chr(i+61)]=i | |
69 return answer_assoc | |
70 | |
71 def digit_classes_assoc(): | |
72 answer_assoc = {} | |
73 for i in range(0,10): | |
74 answer_assoc[str(i)]=i | |
75 return answer_assoc | |
76 | |
77 def tsix_classes_assoc(): | |
78 answer_assoc = {} | |
79 for i in range(10,36): | |
80 answer_assoc[chr(i+55)]=i | |
81 answer_assoc[chr(i+87)]=i | |
82 return answer_assoc | |
83 | |
84 def upper_label_assoc(ulabel): | |
85 for i in range(len(ulabel)): | |
86 if ulabel[i] < 10 or ulabel[i] > 35 : | |
87 ulabel[i] = -1 | |
88 return ulabel | |
89 | |
90 def lower_label_assoc(ulabel): | |
91 for i in range(len(ulabel)): | |
92 if ulabel[i] < 36 or ulabel[i] > 61 : | |
93 ulabel[i] = -1 | |
94 return ulabel | |
95 | |
96 def tsix_label_assoc(ulabel): | |
97 for i in range(len(ulabel)): | |
98 if ulabel[i] > 35 and ulabel[i] < 62 : | |
99 ulabel[i] = ulabel[i] - 26 | |
100 return ulabel | |
101 | |
102 def digit_label_assoc(ulabel): | |
103 for i in range(len(ulabel)): | |
104 if ulabel[i] < 0 or ulabel[i] > 9 : | |
105 ulabel[i] = -1 | |
106 | |
107 return ulabel | |
108 | |
109 def classes_answer(type): | |
110 if type == 'all': | |
111 return all_classes_assoc() | |
112 elif type == '36': | |
113 return tsix_classes_assoc() | |
114 elif type == 'lower': | |
115 return lower_classes_assoc() | |
116 elif type == 'upper': | |
117 return upper_classes_assoc() | |
118 elif type == 'digits': | |
119 return digit_classes_assoc() | |
120 else: | |
121 raise ('Inapropriate option for the type of classification :' + type) | |
122 | |
123 | |
124 def test_error(assoc_type=TYPE): | |
125 answer_assoc = classes_answer(assoc_type) | |
126 | |
29 turks = [] | 127 turks = [] |
30 reader = csv.DictReader(open(CVSFILE), delimiter=',') | 128 reader = csv.DictReader(open(CVSFILE), delimiter=',') |
31 entries = [ turk for turk in reader ] | 129 entries = [ turk for turk in reader ] |
32 | 130 |
33 errors = numpy.zeros((len(entries),)) | 131 errors = numpy.zeros((len(entries),)) |
35 raise Exception('Wrong number of entries or turks_per_batch') | 133 raise Exception('Wrong number of entries or turks_per_batch') |
36 | 134 |
37 total_uniq_entries = len(entries) / turks_per_batch | 135 total_uniq_entries = len(entries) / turks_per_batch |
38 | 136 |
39 errors = numpy.zeros((len(entries),)) | 137 errors = numpy.zeros((len(entries),)) |
138 num_examples = numpy.zeros((len(entries),)) | |
40 error_means = numpy.zeros((total_uniq_entries,)) | 139 error_means = numpy.zeros((total_uniq_entries,)) |
41 error_variances = numpy.zeros((total_uniq_entries,)) | 140 error_variances = numpy.zeros((total_uniq_entries,)) |
42 | 141 |
43 PATH = DATASET_PATH[find_dataset(entries[0])] | |
44 for i in range(total_uniq_entries): | 142 for i in range(total_uniq_entries): |
45 for t in range(turks_per_batch): | 143 for t in range(turks_per_batch): |
46 errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],PATH) | 144 errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type) |
47 error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() | 145 #error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() |
48 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() | 146 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() |
49 | 147 |
50 decimal.getcontext().prec = 3 | 148 percentage_error = 100. * errors.sum() / num_examples.sum() |
51 print 'Total entries : ' + str(len(entries)) | 149 print 'Testing on : ' + str(assoc_type) |
150 print 'Total entries : ' + str(num_examples.sum()) | |
52 print 'Turks per batch : ' + str(turks_per_batch) | 151 print 'Turks per batch : ' + str(turks_per_batch) |
53 print 'Average test error : ' + str(error_means.mean()*image_per_batch) +'%' | 152 print 'Average test error : ' + str(percentage_error) +'%' |
54 print 'Error variance : ' + str(error_variances.mean()*image_per_batch) +'%' | 153 print 'Error variance : ' + str(error_variances.mean()*image_per_batch) +'%' |
55 | 154 |
56 | 155 |
57 def find_dataset(entry): | 156 def find_dataset(entry): |
58 file = parse_filename(entry[img_url]) | 157 file = parse_filename(entry[img_url]) |
59 return file.split('_')[0] | 158 return file.split('_')[0] |
60 | 159 |
61 def get_error(entry,PATH): | 160 def get_error(entry, type): |
62 file = parse_filename(entry[img_url]) | 161 answer_assoc = classes_answer(type) |
63 f = open(PATH+file,'r') | 162 labels = get_labels(entry,type) |
64 labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.') | |
65 f.close() | |
66 test_error = 0 | 163 test_error = 0 |
164 cnt = 0 | |
67 for i in range(len(answer_labels)): | 165 for i in range(len(answer_labels)): |
68 answer = entry[answer_labels[i]] | 166 if labels[i] == -1: |
69 if len(answer) != 0: | 167 cnt+=1 |
168 else: | |
169 answer = entry[answer_labels[i]] | |
70 try: | 170 try: |
71 if answer_assoc[answer] != int(labels[i]): | 171 if answer_assoc[answer] != labels[i]: |
72 test_error+=1 | 172 test_error+=1 |
73 except: | 173 except: |
74 test_error+=1 | 174 test_error+=1 |
75 else: | 175 return test_error,image_per_batch-cnt |
76 test_error+=1 | 176 |
77 return test_error | 177 |
78 | 178 def get_labels(entry,type): |
179 file = parse_filename(entry[img_url]) | |
180 path = DATASET_PATH[find_dataset(entry)] | |
181 f = open(path+file,'r') | |
182 str_labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.') | |
183 unrestricted_labels = [ int(element) for element in str_labels ] | |
184 if type == 'all': | |
185 return unrestricted_labels | |
186 elif type == '36': | |
187 return tsix_label_assoc(unrestricted_labels) | |
188 elif type == 'lower': | |
189 return lower_label_assoc(unrestricted_labels) | |
190 elif type == 'upper': | |
191 return upper_label_assoc(unrestricted_labels) | |
192 elif type == 'digits': | |
193 return digit_label_assoc(unrestricted_labels) | |
194 else: | |
195 raise ('Inapropriate option for the type of classification :' + str(type)) | |
196 | |
79 | 197 |
80 def parse_filename(string): | 198 def parse_filename(string): |
81 filename = string.split('/')[-1] | 199 filename = string.split('/')[-1] |
82 return filename.split('.')[0]+'.txt' | 200 return filename.split('.')[0]+'.txt' |
83 | 201 |
84 if __name__ =='__main__': | 202 if __name__ =='__main__': |
85 import sys | 203 import sys |
86 CVSFILE = sys.argv[1] | 204 CVSFILE = sys.argv[1] |
87 test_error() | 205 test_error(sys.argv[2]) |