deep/amt/amt.py @ 419:c91d7b67fa41 (Mercurial repository ift6266)
Commit message: Fixed a small error in the names of the pretraining parameter files
author | SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca>
date | Fri, 30 Apr 2010 14:48:08 -0400
parents | 6478eef4f8aa
children | 5777b5041ac9
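amt.py compares each worker's answer characters against integer class labels using one fixed convention. '0'-'9' are classes 0-9, 'A'-'Z' are 10-35 (built with chr(i+55)), and 'a'-'z' are 36-61 (built with chr(i+61)). The '36' setting folds lowercase onto uppercase by subtracting 26. The sketch below restates that convention as two small functions. The names char_to_class and fold_to_36 are hypothetical and are not part of the script.

```python
# Sketch of the 62-class index convention used by amt.py.
# char_to_class and fold_to_36 are hypothetical helpers, not part of the script.

def char_to_class(c):
    """Map '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61."""
    if len(c) != 1:
        raise ValueError('expected a single character: %r' % c)
    if '0' <= c <= '9':
        return ord(c) - ord('0')
    if 'A' <= c <= 'Z':
        return ord(c) - 55   # ord('A') == 65, so 'A' -> 10 (same as chr(i+55))
    if 'a' <= c <= 'z':
        return ord(c) - 61   # ord('a') == 97, so 'a' -> 36 (same as chr(i+61))
    raise ValueError('not an alphanumeric character: %r' % c)


def fold_to_36(label):
    """Case-insensitive label: lowercase 36-61 becomes uppercase 10-35."""
    return label - 26 if 36 <= label <= 61 else label


if __name__ == '__main__':
    for c in '7Gg':
        print(c, char_to_class(c), fold_to_36(char_to_class(c)))
```

Under Python 3, running it prints `7 7 7`, `G 16 16` and `g 42 16`. After folding, 'G' and 'g' share class 16.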
# Script usage : python amt.py filename.csv type
# where type is one of: all, 36, upper, lower, digits
"""
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv all
Testing on : all
Total entries : 300.0
Turks per batch : 3
Average test error : 45.3333333333%
Error variance : 7.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv 36
Testing on : 36
Total entries : 300.0
Turks per batch : 3
Average test error : 51.6666666667%
Error variance : 3.33333333333%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv upper
Testing on : upper
Total entries : 63.0
Turks per batch : 3
Average test error : 53.9682539683%
Error variance : 1.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv lower
Testing on : lower
Total entries : 135.0
Turks per batch : 3
Average test error : 37.037037037%
Error variance : 3.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv digits
Testing on : digits
Total entries : 102.0
Turks per batch : 3
Average test error : 50.9803921569%
Error variance : 1.33333333333%
"""

import csv, numpy, re

# Directories holding the ground-truth label files of each dataset
DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/',
                 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/',
                 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }

CVSFILE = None
#PATH = None
# CSV columns holding a worker's answers for the 10 images of a batch
answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ]
# CSV column holding the image URL, used to locate the ground-truth label file
img_url = 'Input.image_url'
# Number of workers (Turks) who labelled each batch
turks_per_batch = 3
# Number of images in each batch
image_per_batch = 10
TYPE = None

def all_classes_assoc():
    # 62 classes: '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i
    return answer_assoc

def upper_classes_assoc():
    answer_assoc = {}
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
    return answer_assoc

def lower_classes_assoc():
    answer_assoc = {}
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i
    return answer_assoc

def digit_classes_assoc():
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    return answer_assoc

def tsix_classes_assoc():
    # 36 classes: digits 0-9 plus case-insensitive letters ('A' and 'a' -> 10, ..., 'Z' and 'z' -> 35)
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
        answer_assoc[chr(i+87)]=i
    return answer_assoc

# The *_label_assoc functions restrict the ground-truth labels to one class
# subset: labels outside the subset are set to -1 and are not scored.
def upper_label_assoc(ulabel):
    for i in range(len(ulabel)):
        if ulabel[i] < 10 or ulabel[i] > 35 :
            ulabel[i] = -1
    return ulabel

def lower_label_assoc(ulabel):
    for i in range(len(ulabel)):
        if ulabel[i] < 36 or ulabel[i] > 61 :
            ulabel[i] = -1
    return ulabel

def tsix_label_assoc(ulabel):
    # Fold lowercase labels (36-61) onto the matching uppercase labels (10-35)
    for i in range(len(ulabel)):
        if ulabel[i] > 35 and ulabel[i] < 62 :
            ulabel[i] = ulabel[i] - 26
    return ulabel

def digit_label_assoc(ulabel):
    for i in range(len(ulabel)):
        if ulabel[i] < 0 or ulabel[i] > 9 :
            ulabel[i] = -1
    return ulabel

def classes_answer(type):
    # Map each accepted answer character to its class index for the chosen subset
    if type == 'all':
        return all_classes_assoc()
    elif type == '36':
        return tsix_classes_assoc()
    elif type == 'lower':
        return lower_classes_assoc()
    elif type == 'upper':
        return upper_classes_assoc()
    elif type == 'digits':
        return digit_classes_assoc()
    else:
        raise ValueError('Inappropriate option for the type of classification: ' + str(type))


def test_error(assoc_type=TYPE):
    classes_answer(assoc_type)  # fails early if assoc_type is not a valid option

    # Each CSV row holds one worker's answers for one batch of image_per_batch
    # images; consecutive groups of turks_per_batch rows are assumed to cover
    # the same batch.
    with open(CVSFILE) as csv_file:
        entries = [ turk for turk in csv.DictReader(csv_file, delimiter=',') ]

    if len(entries) % turks_per_batch != 0 :
        raise Exception('Wrong number of entries or turks_per_batch')

    total_uniq_entries = len(entries) // turks_per_batch

    errors = numpy.zeros((len(entries),))
    num_examples = numpy.zeros((len(entries),))
    error_means = numpy.zeros((total_uniq_entries,))
    error_variances = numpy.zeros((total_uniq_entries,))

    for i in range(total_uniq_entries):
        for t in range(turks_per_batch):
            errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
        #error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
        error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()

    # Pooled error rate over all images that belong to the selected class subset
    percentage_error = 100. * errors.sum() / num_examples.sum()
    print('Testing on : ' + str(assoc_type))
    print('Total entries : ' + str(num_examples.sum()))
    print('Turks per batch : ' + str(turks_per_batch))
    print('Average test error : ' + str(percentage_error) +'%')
    print('Error variance : ' + str(error_variances.mean()*image_per_batch) +'%')


def find_dataset(entry):
    file = parse_filename(entry[img_url])
    return file.split('_')[0]

def get_error(entry, type):
    # Returns (number of wrong answers, number of images scored) for one CSV row
    answer_assoc = classes_answer(type)
    labels = get_labels(entry,type)
    test_error = 0
    cnt = 0
    for i in range(len(answer_labels)):
        if labels[i] == -1:
            # Image outside the selected class subset: not scored
            cnt+=1
        else:
            answer = entry[answer_labels[i]]
            try:
                if answer_assoc[answer] != labels[i]:
                    test_error+=1
            except KeyError:
                # An answer outside the selected class subset (or empty) counts as an error
                test_error+=1
    return test_error,image_per_batch-cnt


def get_labels(entry,type):
    file = parse_filename(entry[img_url])
    path = DATASET_PATH[find_dataset(entry)]
    # The first line is expected to hold the labels as whole-number floats inside
    # brackets: whitespace is removed, '[' and the trailing '.]' are stripped,
    # and the rest is split on '.'
    with open(path+file,'r') as f:
        str_labels = re.sub(r"\s+", "", f.readline()).strip()[1:-2].split('.')
    unrestricted_labels = [ int(element) for element in str_labels ]
    if type == 'all':
        return unrestricted_labels
    elif type == '36':
        return tsix_label_assoc(unrestricted_labels)
    elif type == 'lower':
        return lower_label_assoc(unrestricted_labels)
    elif type == 'upper':
        return upper_label_assoc(unrestricted_labels)
    elif type == 'digits':
        return digit_label_assoc(unrestricted_labels)
    else:
        raise ValueError('Inappropriate option for the type of classification: ' + str(type))


def parse_filename(string):
    filename = string.split('/')[-1]
    return filename.split('.')[0]+'.txt'

if __name__ =='__main__':
    import sys
    CVSFILE = sys.argv[1]
    test_error(sys.argv[2])
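test_error reports two numbers. The "Average test error" is pooled: 100 times the total number of wrong answers, divided by the total number of scored images. Images whose label was set to -1 are not counted. The "Error variance" is the variance of the error counts among the turks_per_batch workers of each batch (numpy's default ddof=0), averaged over batches and then multiplied by image_per_batch. The sketch below restates that arithmetic in vectorized form. It assumes the per-row error counts and scored-image counts are already available as lists; the function name summarize and the toy numbers are hypothetical.

```python
import numpy


def summarize(errors, counted, turks_per_batch=3, image_per_batch=10):
    """Return (pooled error in percent, scaled mean per-batch variance).

    errors[k]:  wrong answers in CSV row k
    counted[k]: images of row k that belong to the chosen class subset
    Assumes each consecutive block of turks_per_batch rows covers the
    same batch of images, as the script does.
    """
    errors = numpy.asarray(errors, dtype=float)
    counted = numpy.asarray(counted, dtype=float)
    if errors.size % turks_per_batch != 0:
        raise ValueError('number of rows is not a multiple of turks_per_batch')

    # Pooled rate: total mistakes over total scored images,
    # not a mean of per-row rates.
    percentage_error = 100. * errors.sum() / counted.sum()

    # One row per batch, one column per worker; var() uses ddof=0 like the script.
    per_batch_var = errors.reshape(-1, turks_per_batch).var(axis=1)
    return percentage_error, per_batch_var.mean() * image_per_batch


if __name__ == '__main__':
    print(summarize([2, 3, 4, 1, 1, 1], [10] * 6))
```

Take errors [2, 3, 4, 1, 1, 1] with 10 scored images per row. The pooled error is 12 / 60, i.e. 20.0. The first batch has variance 2/3 and the second 0, so their mean of 1/3 times 10 gives about 3.33.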