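Mercurial repository ift6266: annotate deep/amt/amt.py @ 583:ae77edb9df67
Changeset 583: DIRO techreport, sent to arXiv
author: Yoshua Bengio <bengioy@iro.umontreal.ca>
date: Sat, 18 Sep 2010 16:44:46 -0400
parents: 7bdd412754ea

Changesets that last modified the lines below include:
399 (99905d9bc9dd), humel: Initial commit for calculating the test error of the AMT classifier
402 (83413ac10913), humel: Added more stats printing. Now you don't need to specify which dataset you are testing; it is detected automatically
412 (6478eef4f8aa), humel: Added support for calculating the test error over different sets of classes (lower, upper, digits, all, 36)
430 (5777b5041ac9), Dumitru Erhan <dumitru.erhan@gmail.com>: fixed error computation for 36 classes
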
# Script usage: python amt.py filename.csv type   (type is one of: all, 36, lower, upper, digits)
"""
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv all
Testing on : all
Total entries : 300.0
Turks per batch : 3
Average test error : 45.3333333333%
Error variance : 7.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv 36
Testing on : 36
Total entries : 300.0
Turks per batch : 3
Average test error : 51.6666666667%
Error variance : 3.33333333333%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv upper
Testing on : upper
Total entries : 63.0
Turks per batch : 3
Average test error : 53.9682539683%
Error variance : 1.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv lower
Testing on : lower
Total entries : 135.0
Turks per batch : 3
Average test error : 37.037037037%
Error variance : 3.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv digits
Testing on : digits
Total entries : 102.0
Turks per batch : 3
Average test error : 50.9803921569%
Error variance : 1.33333333333%
"""
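The logged percentages correspond to whole numbers of misclassified images. In this log, the `all` figure also equals the pooled error of the three disjoint `lower`, `upper` and `digits` runs (135 + 63 + 102 = 300 entries). The short Python 3 check below verifies that arithmetic. The error counts are inferred by rounding the logged percentages; they are not read from pnist.csv.

```python
# Logged (entries, average test error %) pairs from the pnist.csv runs above.
runs = {
    'lower':  (135, 37.037037037),
    'upper':  (63, 53.9682539683),
    'digits': (102, 50.9803921569),
    'all':    (300, 45.3333333333),
}

def implied_errors(entries, percent):
    """Number of misclassified images implied by a logged percentage."""
    return int(round(entries * percent / 100.0))

counts = {name: implied_errors(n, p) for name, (n, p) in runs.items()}

# lower, upper and digits are disjoint class sets, so their errors can be pooled.
subset_total = counts['lower'] + counts['upper'] + counts['digits']
subset_entries = sum(runs[k][0] for k in ('lower', 'upper', 'digits'))
pooled_percent = 100.0 * subset_total / subset_entries
print(counts, pooled_percent)
```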

import csv,numpy,re,decimal
from ift6266 import datasets
from pylearn.io import filetensor as ft

# Per-class training-set frequency tables (62 classes), written by frequency_table()
fnist = open('nist_train_class_freq.ft','rb')
fp07 = open('p07_train_class_freq.ft','rb')
fpnist = open('pnist_train_class_freq.ft','rb')

nist_freq_table = ft.read(fnist)
p07_freq_table = ft.read(fp07)
pnist_freq_table = ft.read(fpnist)

fnist.close();fp07.close();fpnist.close()

DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/',
                 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/',
                 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }

freq_tables = { 'nist' : nist_freq_table,
                'p07' : p07_freq_table,
                'pnist': pnist_freq_table }


CVSFILE = None    # AMT results CSV to score (the filename.csv argument)
#PATH = None
# CSV columns holding the turker's answers for the 10 images of a batch
answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ]
# CSV column holding the image URL; its file name identifies the dataset and label file
img_url = 'Input.image_url'
turks_per_batch = 3
image_per_batch = 10
TYPE = None

def all_classes_assoc():
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i
    return answer_assoc

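`all_classes_assoc()` uses a 62-class index layout: digits take indices 0-9, upper-case letters 10-35 and lower-case letters 36-61. The `chr(i+55)` and `chr(i+61)` offsets simply shift `i` onto the ASCII codes of 'A' (65) and 'a' (97). The minimal Python 3 sketch below builds the same table from the standard `string` constants and checks it against the offset arithmetic. The name `all_classes_assoc_sketch` is hypothetical.

```python
import string

def all_classes_assoc_sketch():
    """Map answer characters to class indices: '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61."""
    symbols = string.digits + string.ascii_uppercase + string.ascii_lowercase
    return {symbol: index for index, symbol in enumerate(symbols)}

assoc = all_classes_assoc_sketch()

# Same offsets as the original loops: chr(i+55) for 10..35 and chr(i+61) for 36..61.
assert all(assoc[chr(i + 55)] == i for i in range(10, 36))
assert all(assoc[chr(i + 61)] == i for i in range(36, 62))
print(len(assoc), assoc['A'], assoc['z'])  # 62 10 61
```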
def upper_classes_assoc():
    answer_assoc = {}
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
    return answer_assoc

def lower_classes_assoc():
    answer_assoc = {}
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i
    return answer_assoc

def digit_classes_assoc():
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    return answer_assoc

def tsix_classes_assoc():
    # 36 classes: the upper- and lower-case forms of a letter map to the same index (10-35)
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
        answer_assoc[chr(i+87)]=i
    return answer_assoc

# The *_label_assoc helpers restrict the true labels to one class subset.
# They modify the list in place: labels outside the subset become -1 and are skipped when scoring.
def upper_label_assoc(ulabel):
    for i in range(len(ulabel)):
        if ulabel[i] < 10 or ulabel[i] > 35 :
            ulabel[i] = -1
    return ulabel

def lower_label_assoc(ulabel):
    for i in range(len(ulabel)):
        if ulabel[i] < 36 or ulabel[i] > 61 :
            ulabel[i] = -1
    return ulabel

def tsix_label_assoc(ulabel):
    # 36 classes: fold lower-case labels (36-61) onto the upper-case ones (10-35)
    for i in range(len(ulabel)):
        if ulabel[i] > 35 and ulabel[i] < 62 :
            ulabel[i] = ulabel[i] - 26
    return ulabel

def digit_label_assoc(ulabel):
    for i in range(len(ulabel)):
        if ulabel[i] < 0 or ulabel[i] > 9 :
            ulabel[i] = -1

    return ulabel

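The `*_label_assoc` helpers rewrite the true labels in place. Labels outside the chosen subset become -1 and are skipped when scoring. In the 36-class setting, lower-case labels 36-61 are instead folded onto the upper-case indices 10-35. The sketch below applies the same rules without mutating its input. `SUBSET_RANGES`, `restrict_labels` and `fold_case` are hypothetical names, and a dictionary of inclusive index ranges stands in for the separate functions.

```python
# Inclusive class-index ranges for each subset (same bounds as the original helpers).
SUBSET_RANGES = {
    'digits': (0, 9),
    'upper': (10, 35),
    'lower': (36, 61),
}

def restrict_labels(labels, subset):
    """Return a new list where labels outside `subset` are replaced by -1."""
    low, high = SUBSET_RANGES[subset]
    return [label if low <= label <= high else -1 for label in labels]

def fold_case(labels):
    """36-class setting: map lower-case indices 36..61 onto upper-case 10..35."""
    return [label - 26 if 36 <= label <= 61 else label for label in labels]

labels = [3, 12, 40, 61, 9]
print(restrict_labels(labels, 'upper'))  # [-1, 12, -1, -1, -1]
print(fold_case(labels))                 # [3, 12, 14, 35, 9]
```

Returning new lists avoids the original's side effect, where the caller's label list is changed by the restriction.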
def classes_answer(type):
    if type == 'all':
        return all_classes_assoc()
    elif type == '36':
        return tsix_classes_assoc()
    elif type == 'lower':
        return lower_classes_assoc()
    elif type == 'upper':
        return upper_classes_assoc()
    elif type == 'digits':
        return digit_classes_assoc()
    else:
        raise ValueError('Inappropriate option for the type of classification: ' + str(type))


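`classes_answer()` dispatches on the type string with an if/elif chain. The sketch below does the same dispatch with a lookup table. `CLASS_SETS`, `_index_map` and `classes_answer_sketch` are hypothetical names. The builders recreate the same index layouts from the `string` constants rather than calling the original `*_classes_assoc` functions, and unknown types raise `ValueError`.

```python
import string

def _index_map(symbols, first_index):
    """Map each symbol to consecutive class indices starting at first_index."""
    return {symbol: first_index + k for k, symbol in enumerate(symbols)}

# Hypothetical dispatch table mirroring the five *_classes_assoc builders.
CLASS_SETS = {
    'all': lambda: _index_map(string.digits + string.ascii_uppercase + string.ascii_lowercase, 0),
    'digits': lambda: _index_map(string.digits, 0),
    'upper': lambda: _index_map(string.ascii_uppercase, 10),
    'lower': lambda: _index_map(string.ascii_lowercase, 36),
    # 36 classes: lower-case letters share the indices of their upper-case forms.
    '36': lambda: {**_index_map(string.digits + string.ascii_uppercase, 0),
                   **_index_map(string.ascii_lowercase, 10)},
}

def classes_answer_sketch(kind):
    """Return the answer-to-class mapping for a class subset, or raise ValueError."""
    if kind not in CLASS_SETS:
        raise ValueError('Inappropriate option for the type of classification: %r' % (kind,))
    return CLASS_SETS[kind]()

print(classes_answer_sketch('36')['q'], classes_answer_sketch('36')['Q'])
```

Adding a new class subset then only needs one more entry in the table.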
def test_error(assoc_type=TYPE,consensus=True):
    answer_assoc = classes_answer(assoc_type)

    turks = []
    reader = csv.DictReader(open(CVSFILE), delimiter=',')
    entries = [ turk for turk in reader ]

    errors = numpy.zeros((len(entries),))
    if len(entries) % turks_per_batch != 0 :
        raise Exception('Wrong number of entries or turks_per_batch')

    total_uniq_entries = len(entries) // turks_per_batch


    error_variances = numpy.zeros((total_uniq_entries,))

    if consensus:
        errors = numpy.zeros((total_uniq_entries,))
        num_examples = numpy.zeros((total_uniq_entries,))
        for i in range(total_uniq_entries):
            errors[i],num_examples[i] = get_turk_consensus_error(entries[i*turks_per_batch:(i+1)*turks_per_batch],assoc_type)
            # NOTE: in consensus mode errors holds one value per batch, not per turker,
            # so this slice does not measure the spread between turkers.
            error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
    else:
        errors = numpy.zeros((len(entries),))
        num_examples = numpy.zeros((len(entries),))
        for i in range(total_uniq_entries):
            for t in range(turks_per_batch):
                errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
            error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()

    percentage_error = 100. * errors.sum() / num_examples.sum()
    print 'Testing on : ' + str(assoc_type)
    print 'Total entries : ' + str(num_examples.sum())
    print 'Turks per batch : ' + str(turks_per_batch)
    print 'Average test error : ' + str(percentage_error) +'%'
    print 'Error variance : ' + str(error_variances.mean()*image_per_batch) +'%'


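In the non-consensus branch, `test_error()` does three things:

- it scores every CSV row separately;
- it pools the errors into one percentage;
- it reports the mean, over batches, of the variance of the per-turker error counts, multiplied by `image_per_batch`.

A vectorised numpy sketch of that arithmetic follows. It assumes the per-row error and example counts are already computed. `summarize_errors` is a hypothetical name, and the inputs in the usage line are made-up illustration values, not AMT results.

```python
import numpy as np

def summarize_errors(errors, num_examples, turks_per_batch=3, image_per_batch=10):
    """Pooled % error and mean per-batch variance (scaled as in test_error)."""
    errors = np.asarray(errors, dtype=float)
    num_examples = np.asarray(num_examples, dtype=float)
    if errors.size % turks_per_batch != 0:
        raise ValueError('Wrong number of entries or turks_per_batch')
    percentage_error = 100.0 * errors.sum() / num_examples.sum()
    # Rows are in CSV order, so consecutive groups of turks_per_batch rows form a batch.
    per_batch = errors.reshape(-1, turks_per_batch)
    # var() with the default ddof=0 is the population variance, as in the original .var() call.
    error_variance = per_batch.var(axis=1).mean() * image_per_batch
    return percentage_error, error_variance

pct, var = summarize_errors([2, 3, 4, 5, 5, 5], [10] * 6)
print(pct, var)
```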
def find_dataset(entry):
    # The dataset name ('nist', 'p07' or 'pnist') is the file-name prefix before the first '_'
    file = parse_filename(entry[img_url])
    return file.split('_')[0]

def get_error(entry, type):
    answer_assoc = classes_answer(type)
    labels = get_labels(entry,type)
    test_error = 0
    cnt = 0
    for i in range(len(answer_labels)):
        if labels[i] == -1:
            cnt+=1
        else:
            answer = entry[answer_labels[i]]
            try:
                if answer_assoc[answer] != labels[i]:
                    test_error+=1
            except KeyError:
                # The answer is not a character of the selected class set
                test_error+=1
    return test_error,image_per_batch-cnt

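`get_error()` scores one CSV row using these rules:

- images whose true label was masked to -1 are skipped;
- an answer that maps to the true label is correct;
- any other answer counts as an error, including one that is not in the selected class set (the `KeyError` case).

A self-contained sketch of that rule follows. The name `count_errors` is hypothetical, it uses `dict.get` in place of the try/except, and it assumes one true label per answer column.

```python
def count_errors(answers, labels, answer_assoc):
    """Score one turker's answers for one batch.

    Labels equal to -1 are outside the evaluated class subset and are skipped;
    an answer missing from answer_assoc counts as an error.
    Returns (number of errors, number of scored images).
    """
    errors = 0
    scored = 0
    for answer, label in zip(answers, labels):
        if label == -1:
            continue
        scored += 1
        # .get returns None for unknown answers, which never equals a class index.
        if answer_assoc.get(answer) != label:
            errors += 1
    return errors, scored

digits = {str(d): d for d in range(10)}
print(count_errors(['3', '7', 'x', '1'], [3, -1, 5, 2], digits))  # (2, 3)
```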
def get_turk_consensus_error(entries, type):
    answer_assoc = classes_answer(type)
    labels = get_labels(entries[0],type)
    test_error = 0
    cnt = 0
    freq_t = freq_tables[find_dataset(entries[0])]
    for i in range(len(answer_labels)):
        if labels[i] == -1:
            cnt+=1
        else:
            # Assumes exactly three turkers per batch (turks_per_batch == 3)
            answers = [ entry[answer_labels[i]] for entry in entries ]
            if answers[0] != answers[1] and answers[1] != answers[2] and answers[0] != answers[2]:
                # All three turkers disagree: keep the answer whose class is most frequent in the training set
                m = max([ freq_t[answer_assoc[answer]] for answer in answers])
                for answer in answers:
                    if freq_t[answer_assoc[answer]] == m :
                        a = answer
            else:
                # At least two turkers agree: keep the majority answer
                for answer in answers:
                    if answers.count(answer) > 1 :
                        a = answer
            try:
                if answer_assoc[a] != labels[i]:
                    test_error+=1
            except KeyError:
                test_error+=1
    return test_error,image_per_batch-cnt

def frequency_table():
    filenames = ['nist_train_class_freq.ft','p07_train_class_freq.ft','pnist_train_class_freq.ft']
    iterators = [datasets.nist_all(),datasets.nist_P07(),datasets.PNIST07()]
    for dataset,filename in zip(iterators,filenames):
        freq_table = numpy.zeros(62)
        for x,y in dataset.train(1):
            freq_table[int(y)]+=1
        f = open(filename,'wb')
        ft.write(f,freq_table)
        f.close()

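The vote in `get_turk_consensus_error()` assumes exactly three turkers. If at least two agree, their answer wins. If all three differ, the answer whose class is most frequent in the training-set frequency table is kept.

The sketch below generalises the same rule to any number of turkers with `collections.Counter`. The function name and the toy frequency list are hypothetical. It differs from the original in two ways:

- answers outside `answer_assoc` are ranked lowest, instead of raising `KeyError` as the original tie-break would;
- frequency ties go to the first tied answer, whereas the loop above keeps the last.

```python
from collections import Counter

def consensus_answer(answers, answer_assoc, class_freq):
    """Majority vote over turker answers; ties go to the class that is most
    frequent in the training set (answers outside answer_assoc rank lowest)."""
    counts = Counter(answers)
    top = max(counts.values())
    # Tied candidates, in order of first appearance (dicts keep insertion order).
    tied = [a for a in dict.fromkeys(answers) if counts[a] == top]
    if len(tied) == 1:
        return tied[0]

    def train_freq(answer):
        index = answer_assoc.get(answer)
        return class_freq[index] if index is not None else -1

    # max() returns the first candidate with the highest training frequency.
    return max(tied, key=train_freq)

assoc = {'0': 0, '1': 1, '2': 2}
freq = [50, 10, 80]  # hypothetical training-set counts per class
print(consensus_answer(['1', '1', '2'], assoc, freq))  # majority -> '1'
print(consensus_answer(['0', '1', '2'], assoc, freq))  # all differ -> '2'
```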
def get_labels(entry,type):
    file = parse_filename(entry[img_url])
    path = DATASET_PATH[find_dataset(entry)]
    f = open(path+file,'r')
    # The first line is expected to hold the true labels as a bracketed list of floats:
    # remove all whitespace, drop the leading '[' and the trailing '.]', then split on '.'
    str_labels = re.sub(r"\s+", "", f.readline()).strip()[1:-2].split('.')
    f.close()
    unrestricted_labels = [ int(element) for element in str_labels ]
    if type == 'all':
        return unrestricted_labels
    elif type == '36':
        return tsix_label_assoc(unrestricted_labels)
    elif type == 'lower':
        return lower_label_assoc(unrestricted_labels)
    elif type == 'upper':
        return upper_label_assoc(unrestricted_labels)
humel
parents:
403
diff
changeset
|
254 elif type == 'digits': |
6478eef4f8aa
Added support for calculating the test error over different set of classes (lower,upper,digits,all,36)
humel
parents:
403
diff
changeset
|
255 return digit_label_assoc(unrestricted_labels) |
6478eef4f8aa
Added support for calculating the test error over different set of classes (lower,upper,digits,all,36)
humel
parents:
403
diff
changeset
|
256 else: |
6478eef4f8aa
Added support for calculating the test error over different set of classes (lower,upper,digits,all,36)
humel
parents:
403
diff
changeset
|
257 raise ('Inapropriate option for the type of classification :' + str(type)) |
6478eef4f8aa
Added support for calculating the test error over different set of classes (lower,upper,digits,all,36)
humel
parents:
403
diff
changeset
|
258 |
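The slicing in get_labels depends on the exact layout of the label file's first line. A minimal sketch, assuming that line is a printed NumPy float array such as "[ 12.   3.  45.]" (inferred from the [1:-2] slice, not confirmed by the source), shows what the expression produces next to a hypothetical alternative, parse_label_line, that reads whitespace-separated numbers without fixed slice offsets.

```python
import re


def parse_like_get_labels(line):
    # Same expression as get_labels: remove whitespace, drop the leading
    # '[' and the trailing '.]', then split on the decimal points
    str_labels = re.sub(r"\s+", "", line).strip()[1:-2].split('.')
    return [int(element) for element in str_labels]


def parse_label_line(line):
    # Hypothetical alternative: take every whitespace-separated number
    # between the brackets, whether printed as "12." or "12"
    return [int(float(tok)) for tok in line.strip().strip('[]').split()]


if __name__ == '__main__':
    line = "[ 12.   3.  45.]\n"
    print(parse_like_get_labels(line))  # prints [12, 3, 45]
    print(parse_label_line(line))       # prints [12, 3, 45]
```

The alternative also accepts an integer array such as "[12 3 45]", which the original expression would merge into a single number.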

def parse_filename(string):
    # Keep the last component of the image URL and replace its extension
    # with .txt, the file that get_labels reads the labels from
    filename = string.split('/')[-1]
    return filename.split('.')[0] + '.txt'

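parse_filename keeps only the text before the first '.' of the last URL component, so a name containing extra dots loses more than its extension. The sketch below, using made-up URLs, compares it with a hypothetical variant, parse_filename_ext, that strips only the final extension with posixpath.splitext.

```python
import posixpath


def parse_filename(string):
    # Same logic as amt.py: last path component, text before the first '.'
    filename = string.split('/')[-1]
    return filename.split('.')[0] + '.txt'


def parse_filename_ext(string):
    # Hypothetical variant: replace only the final extension
    filename = posixpath.basename(string)
    return posixpath.splitext(filename)[0] + '.txt'


if __name__ == '__main__':
    url = 'http://example.com/images/img_00042.png'
    print(parse_filename(url), parse_filename_ext(url))        # both img_00042.txt
    dotted = 'http://example.com/images/img.v2.png'
    print(parse_filename(dotted), parse_filename_ext(dotted))  # img.txt img.v2.txt
```

Both agree on plain names; they differ only when the file name itself contains a dot.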
if __name__ == '__main__':
    import sys
    # argv[1]: the AMT results file (CSV); argv[2]: the class subset
    # (all, 36, lower, upper or digits); argv[3]: an integer passed to test_error
    CVSFILE = sys.argv[1]
    test_error(sys.argv[2], int(sys.argv[3]))
    #frequency_table()
