comparison deep/amt/amt.py @ 412:6478eef4f8aa

Added support for calculating the test error over different sets of classes (lower, upper, digits, all, 36)
author humel
date Thu, 29 Apr 2010 14:28:52 -0400
parents a11692910312
children 5777b5041ac9
comparison
equal deleted inserted replaced
411:4f69d915d142 412:6478eef4f8aa
1 # Script usage : python amt.py filname.cvs 1 # Script usage : python amt.py filname.cvs type
2 """
3 [rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv all
4 Testing on : all
5 Total entries : 300.0
6 Turks per batch : 3
7 Average test error : 45.3333333333%
8 Error variance : 7.77777777778%
9 [rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv 36
10 Testing on : 36
11 Total entries : 300.0
12 Turks per batch : 3
13 Average test error : 51.6666666667%
14 Error variance : 3.33333333333%
15 [rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv upper
16 Testing on : upper
17 Total entries : 63.0
18 Turks per batch : 3
19 Average test error : 53.9682539683%
20 Error variance : 1.77777777778%
21 [rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv lower
22 Testing on : lower
23 Total entries : 135.0
24 Turks per batch : 3
25 Average test error : 37.037037037%
26 Error variance : 3.77777777778%
27 [rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv digits
28 Testing on : digits
29 Total entries : 102.0
30 Turks per batch : 3
31 Average test error : 50.9803921569%
32 Error variance : 1.33333333333%
33 """
2 34
3 import csv,numpy,re,decimal 35 import csv,numpy,re,decimal
4 36
5 DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/', 37 DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/',
6 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/', 38 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/',
10 #PATH = None 42 #PATH = None
11 answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ] 43 answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ]
12 img_url = 'Input.image_url' 44 img_url = 'Input.image_url'
13 turks_per_batch = 3 45 turks_per_batch = 3
14 image_per_batch = 10 46 image_per_batch = 10
15 47 TYPE = None
16 def setup_association(): 48
49 def all_classes_assoc():
17 answer_assoc = {} 50 answer_assoc = {}
18 for i in range(0,10): 51 for i in range(0,10):
19 answer_assoc[str(i)]=i 52 answer_assoc[str(i)]=i
20 for i in range(10,36): 53 for i in range(10,36):
21 answer_assoc[chr(i+55)]=i 54 answer_assoc[chr(i+55)]=i
22 for i in range(36,62): 55 for i in range(36,62):
23 answer_assoc[chr(i+61)]=i 56 answer_assoc[chr(i+61)]=i
24 return answer_assoc 57 return answer_assoc
25 58
26 answer_assoc = setup_association() 59 def upper_classes_assoc():
27 60 answer_assoc = {}
28 def test_error(): 61 for i in range(10,36):
62 answer_assoc[chr(i+55)]=i
63 return answer_assoc
64
65 def lower_classes_assoc():
66 answer_assoc = {}
67 for i in range(36,62):
68 answer_assoc[chr(i+61)]=i
69 return answer_assoc
70
71 def digit_classes_assoc():
72 answer_assoc = {}
73 for i in range(0,10):
74 answer_assoc[str(i)]=i
75 return answer_assoc
76
77 def tsix_classes_assoc():
78 answer_assoc = {}
79 for i in range(10,36):
80 answer_assoc[chr(i+55)]=i
81 answer_assoc[chr(i+87)]=i
82 return answer_assoc
83
84 def upper_label_assoc(ulabel):
85 for i in range(len(ulabel)):
86 if ulabel[i] < 10 or ulabel[i] > 35 :
87 ulabel[i] = -1
88 return ulabel
89
90 def lower_label_assoc(ulabel):
91 for i in range(len(ulabel)):
92 if ulabel[i] < 36 or ulabel[i] > 61 :
93 ulabel[i] = -1
94 return ulabel
95
96 def tsix_label_assoc(ulabel):
97 for i in range(len(ulabel)):
98 if ulabel[i] > 35 and ulabel[i] < 62 :
99 ulabel[i] = ulabel[i] - 26
100 return ulabel
101
102 def digit_label_assoc(ulabel):
103 for i in range(len(ulabel)):
104 if ulabel[i] < 0 or ulabel[i] > 9 :
105 ulabel[i] = -1
106
107 return ulabel
108
109 def classes_answer(type):
110 if type == 'all':
111 return all_classes_assoc()
112 elif type == '36':
113 return tsix_classes_assoc()
114 elif type == 'lower':
115 return lower_classes_assoc()
116 elif type == 'upper':
117 return upper_classes_assoc()
118 elif type == 'digits':
119 return digit_classes_assoc()
120 else:
121 raise ('Inapropriate option for the type of classification :' + type)
122
123
124 def test_error(assoc_type=TYPE):
125 answer_assoc = classes_answer(assoc_type)
126
29 turks = [] 127 turks = []
30 reader = csv.DictReader(open(CVSFILE), delimiter=',') 128 reader = csv.DictReader(open(CVSFILE), delimiter=',')
31 entries = [ turk for turk in reader ] 129 entries = [ turk for turk in reader ]
32 130
33 errors = numpy.zeros((len(entries),)) 131 errors = numpy.zeros((len(entries),))
35 raise Exception('Wrong number of entries or turks_per_batch') 133 raise Exception('Wrong number of entries or turks_per_batch')
36 134
37 total_uniq_entries = len(entries) / turks_per_batch 135 total_uniq_entries = len(entries) / turks_per_batch
38 136
39 errors = numpy.zeros((len(entries),)) 137 errors = numpy.zeros((len(entries),))
138 num_examples = numpy.zeros((len(entries),))
40 error_means = numpy.zeros((total_uniq_entries,)) 139 error_means = numpy.zeros((total_uniq_entries,))
41 error_variances = numpy.zeros((total_uniq_entries,)) 140 error_variances = numpy.zeros((total_uniq_entries,))
42 141
43 PATH = DATASET_PATH[find_dataset(entries[0])]
44 for i in range(total_uniq_entries): 142 for i in range(total_uniq_entries):
45 for t in range(turks_per_batch): 143 for t in range(turks_per_batch):
46 errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],PATH) 144 errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
47 error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() 145 #error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
48 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() 146 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
49 147
50 decimal.getcontext().prec = 3 148 percentage_error = 100. * errors.sum() / num_examples.sum()
51 print 'Total entries : ' + str(len(entries)) 149 print 'Testing on : ' + str(assoc_type)
150 print 'Total entries : ' + str(num_examples.sum())
52 print 'Turks per batch : ' + str(turks_per_batch) 151 print 'Turks per batch : ' + str(turks_per_batch)
53 print 'Average test error : ' + str(error_means.mean()*image_per_batch) +'%' 152 print 'Average test error : ' + str(percentage_error) +'%'
54 print 'Error variance : ' + str(error_variances.mean()*image_per_batch) +'%' 153 print 'Error variance : ' + str(error_variances.mean()*image_per_batch) +'%'
55 154
56 155
57 def find_dataset(entry): 156 def find_dataset(entry):
58 file = parse_filename(entry[img_url]) 157 file = parse_filename(entry[img_url])
59 return file.split('_')[0] 158 return file.split('_')[0]
60 159
61 def get_error(entry,PATH): 160 def get_error(entry, type):
62 file = parse_filename(entry[img_url]) 161 answer_assoc = classes_answer(type)
63 f = open(PATH+file,'r') 162 labels = get_labels(entry,type)
64 labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.')
65 f.close()
66 test_error = 0 163 test_error = 0
164 cnt = 0
67 for i in range(len(answer_labels)): 165 for i in range(len(answer_labels)):
68 answer = entry[answer_labels[i]] 166 if labels[i] == -1:
69 if len(answer) != 0: 167 cnt+=1
168 else:
169 answer = entry[answer_labels[i]]
70 try: 170 try:
71 if answer_assoc[answer] != int(labels[i]): 171 if answer_assoc[answer] != labels[i]:
72 test_error+=1 172 test_error+=1
73 except: 173 except:
74 test_error+=1 174 test_error+=1
75 else: 175 return test_error,image_per_batch-cnt
76 test_error+=1 176
77 return test_error 177
78 178 def get_labels(entry,type):
179 file = parse_filename(entry[img_url])
180 path = DATASET_PATH[find_dataset(entry)]
181 f = open(path+file,'r')
182 str_labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.')
183 unrestricted_labels = [ int(element) for element in str_labels ]
184 if type == 'all':
185 return unrestricted_labels
186 elif type == '36':
187 return tsix_label_assoc(unrestricted_labels)
188 elif type == 'lower':
189 return lower_label_assoc(unrestricted_labels)
190 elif type == 'upper':
191 return upper_label_assoc(unrestricted_labels)
192 elif type == 'digits':
193 return digit_label_assoc(unrestricted_labels)
194 else:
195 raise ('Inapropriate option for the type of classification :' + str(type))
196
79 197
80 def parse_filename(string): 198 def parse_filename(string):
81 filename = string.split('/')[-1] 199 filename = string.split('/')[-1]
82 return filename.split('.')[0]+'.txt' 200 return filename.split('.')[0]+'.txt'
83 201
84 if __name__ =='__main__': 202 if __name__ =='__main__':
85 import sys 203 import sys
86 CVSFILE = sys.argv[1] 204 CVSFILE = sys.argv[1]
87 test_error() 205 test_error(sys.argv[2])