annotate deep/amt/amt.py @ 443:89a49dae6cf3
merge
author | Xavier Glorot <glorotxa@iro.umontreal.ca> |
---|---|
date | Mon, 03 May 2010 18:38:58 -0400 |
parents | 5777b5041ac9 |
children | 7bdd412754ea |
# Script usage : python amt.py filename.csv type
# where type is one of : all, 36, upper, lower, digits
"""
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv all
Testing on : all
Total entries : 300.0
Turks per batch : 3
Average test error : 45.3333333333%
Error variance : 7.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv 36
Testing on : 36
Total entries : 300.0
Turks per batch : 3
Average test error : 51.6666666667%
Error variance : 3.33333333333%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv upper
Testing on : upper
Total entries : 63.0
Turks per batch : 3
Average test error : 53.9682539683%
Error variance : 1.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv lower
Testing on : lower
Total entries : 135.0
Turks per batch : 3
Average test error : 37.037037037%
Error variance : 3.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv digits
Testing on : digits
Total entries : 102.0
Turks per batch : 3
Average test error : 50.9803921569%
Error variance : 1.33333333333%
"""

import csv
import re

import numpy

# Directory holding the label files of each dataset
DATASET_PATH = { 'nist'  : '/data/lisa/data/ift6266h10/amt_data/nist/',
                 'p07'   : '/data/lisa/data/ift6266h10/amt_data/p07/',
                 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }

CVSFILE = None
#PATH = None
# CSV columns : the answers ('Answer.c1' ... 'Answer.c10') and the image URL
answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ]
img_url = 'Input.image_url'
turks_per_batch = 3
image_per_batch = 10
TYPE = None

def all_classes_assoc():
    # 62 classes : digits -> 0-9, uppercase -> 10-35, lowercase -> 36-61
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i   # chr(65) == 'A'
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i   # chr(97) == 'a'
    return answer_assoc

def upper_classes_assoc():
    answer_assoc = {}
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
    return answer_assoc

def lower_classes_assoc():
    answer_assoc = {}
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i
    return answer_assoc

def digit_classes_assoc():
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    return answer_assoc

def tsix_classes_assoc():
    # 36 classes : digits -> 0-9, letters -> 10-35 regardless of case
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i   # uppercase letter
        answer_assoc[chr(i+87)]=i   # lowercase letter
    return answer_assoc

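The index ranges used by the `*_classes_assoc` functions (digits 0-9, uppercase 10-35, lowercase 36-61) can also be generated from the constants of the standard `string` module. The following is only a sketch for comparison, not code from this changeset; the names `all_classes` and `tsix_classes` are made up for the example.

```python
import string

def all_classes():
    # 62 classes: '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61
    symbols = string.digits + string.ascii_uppercase + string.ascii_lowercase
    return dict((symbol, index) for index, symbol in enumerate(symbols))

def tsix_classes():
    # 36 classes: a lowercase letter shares the index of its uppercase form
    symbols = string.digits + string.ascii_uppercase
    assoc = dict((symbol, index) for index, symbol in enumerate(symbols))
    for letter in string.ascii_lowercase:
        assoc[letter] = assoc[letter.upper()]
    return assoc
```

Building the table from one ordered string keeps the offsets (55, 61, 87) out of the code, at the cost of relying on the ordering of the `string` constants.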
def upper_label_assoc(ulabel):
    # Mark every label outside the uppercase classes (10-35) as ignored (-1).
    # The list is modified in place.
    for i in range(len(ulabel)):
        if ulabel[i] < 10 or ulabel[i] > 35 :
            ulabel[i] = -1
    return ulabel

def lower_label_assoc(ulabel):
    # Mark every label outside the lowercase classes (36-61) as ignored (-1)
    for i in range(len(ulabel)):
        if ulabel[i] < 36 or ulabel[i] > 61 :
            ulabel[i] = -1
    return ulabel

def tsix_label_assoc(ulabel):
    # Fold the lowercase classes (36-61) onto the uppercase ones (10-35)
    for i in range(len(ulabel)):
        if ulabel[i] > 35 and ulabel[i] < 62 :
            ulabel[i] = ulabel[i] - 26
    return ulabel

def digit_label_assoc(ulabel):
    # Mark every label outside the digit classes (0-9) as ignored (-1)
    for i in range(len(ulabel)):
        if ulabel[i] < 0 or ulabel[i] > 9 :
            ulabel[i] = -1

    return ulabel

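The `*_label_assoc` functions overwrite the list they receive. A non-mutating variant is sketched below under the same conventions (-1 marks an ignored label, and for the '36' setting the lowercase classes 36-61 fold onto 10-35). `restrict_labels` is a hypothetical helper, not part of amt.py.

```python
def restrict_labels(labels, kind):
    """Return a new label list restricted to `kind`; the input is left untouched."""
    ranges = {'digits': (0, 9), 'upper': (10, 35), 'lower': (36, 61)}
    if kind == 'all':
        return list(labels)
    if kind == '36':
        # Lowercase classes share the index of the matching uppercase class
        return [lab - 26 if 36 <= lab <= 61 else lab for lab in labels]
    if kind not in ranges:
        raise ValueError('Unknown type of classification: ' + str(kind))
    low, high = ranges[kind]
    return [lab if low <= lab <= high else -1 for lab in labels]

labels = [3, 12, 40, 61]
print(restrict_labels(labels, 'upper'))   # [-1, 12, -1, -1]
print(restrict_labels(labels, '36'))      # [3, 12, 14, 35]
```

Returning a fresh list means the same labels can be reused for several class sets without reading the label file again.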
def classes_answer(type):
    if type == 'all':
        return all_classes_assoc()
    elif type == '36':
        return tsix_classes_assoc()
    elif type == 'lower':
        return lower_classes_assoc()
    elif type == 'upper':
        return upper_classes_assoc()
    elif type == 'digits':
        return digit_classes_assoc()
    else:
        # A plain string cannot be raised; use a real exception
        raise ValueError('Inappropriate option for the type of classification : ' + str(type))


def test_error(assoc_type=TYPE):
    # Fail early on an unknown type, before reading the CSV file
    classes_answer(assoc_type)

    # One row per turk; consecutive rows belong to the same batch of images
    with open(CVSFILE) as csvfile:
        entries = list(csv.DictReader(csvfile, delimiter=','))

    if len(entries) % turks_per_batch != 0 :
        raise Exception('Wrong number of entries or turks_per_batch')

    # Integer division : the result is used as an array size and a loop bound
    total_uniq_entries = len(entries) // turks_per_batch

    errors = numpy.zeros((len(entries),))
    num_examples = numpy.zeros((len(entries),))
    error_variances = numpy.zeros((total_uniq_entries,))

    for i in range(total_uniq_entries):
        for t in range(turks_per_batch):
            errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
        # Variance of the error counts across the turks of the same batch
        error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()

    percentage_error = 100. * errors.sum() / num_examples.sum()
    print('Testing on : ' + str(assoc_type))
    print('Total entries : ' + str(num_examples.sum()))
    print('Turks per batch : ' + str(turks_per_batch))
    print('Average test error : ' + str(percentage_error) +'%')
    print('Error variance : ' + str(error_variances.mean()*image_per_batch) +'%')


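The two statistics printed by `test_error` can be checked by hand on a small case. The sketch below uses made-up error counts and the standard `statistics` module instead of numpy; `pvariance` is the population variance, which is what numpy's `.var()` computes by default.

```python
from statistics import pvariance

turks_per_batch = 3
# Made-up error counts, one per turk: 2 batches of 3 turks, 10 images each
errors = [4, 5, 6, 3, 3, 3]
num_examples = [10] * len(errors)

# Overall error rate: all errors over all counted images
percentage_error = 100.0 * sum(errors) / sum(num_examples)

# Variance of the error counts inside each batch, then averaged over batches
batches = [errors[i:i + turks_per_batch]
           for i in range(0, len(errors), turks_per_batch)]
variances = [pvariance(batch) for batch in batches]
mean_variance = sum(variances) / len(variances)

print(percentage_error)
print(mean_variance)
```

Here the error rate is 24 errors over 60 images, and the first batch contributes a variance of 2/3 while the second contributes 0. Note that the script multiplies the mean variance by `image_per_batch` before printing it.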
def find_dataset(entry):
    # The dataset name is the part of the file name before the first '_'
    filename = parse_filename(entry[img_url])
    return filename.split('_')[0]

def get_error(entry, type):
    answer_assoc = classes_answer(type)
    labels = get_labels(entry,type)
    n_errors = 0
    cnt = 0   # images ignored for this type (label == -1)
    for i in range(len(answer_labels)):
        if labels[i] == -1:
            cnt+=1
        else:
            answer = entry[answer_labels[i]]
            try:
                if answer_assoc[answer] != labels[i]:
                    n_errors+=1
            except KeyError:
                # Answer outside the class set (empty, invalid character, ...)
                n_errors+=1
    return n_errors,image_per_batch-cnt


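The counting rule of `get_error` can be tried without any CSV or label file. The sketch below is a simplified stand-in: `score_row` is a hypothetical name, and the answers and labels are invented. As in the script, a label of -1 is skipped and an answer missing from the association table counts as an error.

```python
def score_row(answers, labels, assoc):
    """Count the wrong answers among the labels that are not ignored (-1)."""
    wrong = 0
    counted = 0
    for answer, label in zip(answers, labels):
        if label == -1:
            continue          # image outside the tested class set
        counted += 1
        if assoc.get(answer) != label:
            wrong += 1        # wrong class or unrecognised answer
    return wrong, counted

digit_assoc = dict((str(i), i) for i in range(10))
print(score_row(['3', 'x', '7', ''], [3, 5, -1, 9], digit_assoc))   # (2, 3)
```

Using `dict.get` replaces the `try`/`except` of the original: a missing key yields `None`, which never equals an integer label.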
def get_labels(entry,type):
    filename = parse_filename(entry[img_url])
    path = DATASET_PATH[find_dataset(entry)]
    with open(path+filename,'r') as f:
        # First line of the label file : remove the whitespace, drop the
        # leading '[' and the trailing '.]', then split on '.'
        str_labels = re.sub(r"\s+", "",f.readline()).strip()[1:-2].split('.')
    unrestricted_labels = [ int(element) for element in str_labels ]
    if type == 'all':
        return unrestricted_labels
    elif type == '36':
        return tsix_label_assoc(unrestricted_labels)
    elif type == 'lower':
        return lower_label_assoc(unrestricted_labels)
    elif type == 'upper':
        return upper_label_assoc(unrestricted_labels)
    elif type == 'digits':
        return digit_label_assoc(unrestricted_labels)
    else:
        # A plain string cannot be raised; use a real exception
        raise ValueError('Inappropriate option for the type of classification : ' + str(type))


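The slicing in `get_labels` suggests that the first line of a label file looks like a printed array of floats, for example `[  5.  12.  40.]`. That format is an assumption inferred from the code, not something this page confirms. Under that assumption, the parsing step works as follows (`parse_label_line` is a hypothetical name):

```python
import re

def parse_label_line(line):
    compact = re.sub(r"\s+", "", line)   # "[5.12.40.]"
    inner = compact[1:-2]                 # drop "[" and ".]" -> "5.12.40"
    return [int(token) for token in inner.split('.')]

print(parse_label_line("[  5.  12.  40.]\n"))   # [5, 12, 40]
```

The approach only works for whole-number labels: a value such as `5.5` would be split into two tokens.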
def parse_filename(string):
    # Keep the last component of the URL and swap its extension for '.txt'
    filename = string.split('/')[-1]
    return filename.split('.')[0]+'.txt'

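A short illustration of how an image URL leads to a label file and a dataset name. The URL below is hypothetical, and the `<dataset>_<id>` naming pattern is inferred from `find_dataset`, which splits on the first underscore.

```python
def parse_filename(string):
    # Last path component, extension replaced by '.txt'
    filename = string.split('/')[-1]
    return filename.split('.')[0] + '.txt'

url = 'http://example.com/images/pnist_0042.png'   # hypothetical URL
label_file = parse_filename(url)
dataset = label_file.split('_')[0]

print(label_file)   # pnist_0042.txt
print(dataset)      # pnist
```

Because only the text before the first dot is kept, a file name containing several dots would be truncated at the first one.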
if __name__ =='__main__':
    import sys
    if len(sys.argv) != 3:
        sys.exit('Script usage : python amt.py filename.csv type')
    CVSFILE = sys.argv[1]
    test_error(sys.argv[2])
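The script reads its two arguments straight from `sys.argv`. As an alternative sketch, not part of this changeset, the standard `argparse` module can validate the type against the five supported values before any file is opened. The function name `parse_args` and the help texts are made up for the example.

```python
import argparse

def parse_args(argv=None):
    parser = argparse.ArgumentParser(description='Test error of the AMT answers')
    parser.add_argument('csvfile', help='CSV file of AMT results')
    parser.add_argument('type',
                        choices=['all', '36', 'upper', 'lower', 'digits'],
                        help='set of classes to test on')
    return parser.parse_args(argv)

args = parse_args(['pnist.csv', '36'])
print(args.csvfile, args.type)
```

With `choices`, an unsupported type is rejected with a usage message instead of reaching `classes_answer`.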