annotate deep/amt/amt.py @ 425:c06a3d9b5664

small syntax error
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Fri, 30 Apr 2010 16:24:35 -0400
parents 6478eef4f8aa
children 5777b5041ac9
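Lines in this revision come from the following changesets, all by humel:
  399 99905d9bc9dd  Initial commit for calculating the test error of the AMT classifier
  401 86d5e583e278  Fixed class number bug (parent: 399)
  402 83413ac10913  Added more stats printing. Now you dont need to parameters which dataset you are testing, it will detect it automatically (parent: 401)
  403 a11692910312  Undoing some unwanted changes (parent: 402)
  412 6478eef4f8aa  Added support for calculating the test error over different set of classes (lower,upper,digits,all,36) (parent: 403)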
# Script usage : python amt.py filename.csv type
# where type is one of: all, 36, lower, upper, digits
"""
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv all
Testing on : all
Total entries : 300.0
Turks per batch : 3
Average test error : 45.3333333333%
Error variance : 7.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv 36
Testing on : 36
Total entries : 300.0
Turks per batch : 3
Average test error : 51.6666666667%
Error variance : 3.33333333333%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv upper
Testing on : upper
Total entries : 63.0
Turks per batch : 3
Average test error : 53.9682539683%
Error variance : 1.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv lower
Testing on : lower
Total entries : 135.0
Turks per batch : 3
Average test error : 37.037037037%
Error variance : 3.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv digits
Testing on : digits
Total entries : 102.0
Turks per batch : 3
Average test error : 50.9803921569%
Error variance : 1.33333333333%
"""

import csv
import re

import numpy

DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/',
                 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/',
                 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }

CVSFILE = None  # path of the AMT results CSV file, set from the command line
#PATH = None
answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ]  # CSV columns holding the workers' answers
img_url = 'Input.image_url'  # CSV column holding the URL of the batch image
turks_per_batch = 3   # number of workers (CSV rows) per batch
image_per_batch = 10  # number of characters labelled per batch
TYPE = None
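The answer_labels and img_url settings above are column headers of the Amazon Mechanical Turk results CSV. The following minimal sketch is written in Python 3 and is not part of the script. It shows how csv.DictReader exposes such columns. The two-row excerpt, its URL and its file name are hypothetical, and the excerpt has three answer columns instead of ten.

```python
import csv
import io

# Hypothetical excerpt of an AMT results file (real files have ten
# Answer.cN columns and many more metadata columns).
SAMPLE_CSV = (
    "Input.image_url,Answer.c1,Answer.c2,Answer.c3\n"
    "http://example.org/img/pnist_0001.png,A,7,q\n"
    "http://example.org/img/pnist_0001.png,A,1,q\n"
)

def read_answers(text, n_answers=3):
    """Return (image_url, [answers]) for each row of an AMT-style CSV."""
    reader = csv.DictReader(io.StringIO(text), delimiter=',')
    columns = ['Answer.c' + str(i + 1) for i in range(n_answers)]
    return [(row['Input.image_url'], [row[c] for c in columns]) for row in reader]

rows = read_answers(SAMPLE_CSV)
```

Each row is one worker's answers for one image, so the two rows above would be two workers labelling the same batch.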

def all_classes_assoc():
    # Map every answer character to its class index:
    # '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61.
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i
    return answer_assoc
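all_classes_assoc relies on ASCII offsets: 10 + 55 = 65 is ord('A'), and 36 + 61 = 97 is ord('a'). The following small Python 3 check tests that arithmetic. It uses a hypothetical inverse helper, index_to_char, and compares the resulting 62-character ordering with the string module's constants.

```python
import string

def index_to_char(i):
    """Hypothetical inverse of all_classes_assoc: class index (0-61) -> character."""
    if 0 <= i <= 9:
        return str(i)
    if 10 <= i <= 35:
        return chr(i + 55)   # 10 + 55 = 65 = ord('A')
    if 36 <= i <= 61:
        return chr(i + 61)   # 36 + 61 = 97 = ord('a')
    raise ValueError('class index out of range: %d' % i)

# Rebuild the forward table from the inverse.
forward = {index_to_char(i): i for i in range(62)}
alphabet = string.digits + string.ascii_uppercase + string.ascii_lowercase
```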

def upper_classes_assoc():
    # Uppercase letters only: 'A'-'Z' -> 10-35.
    answer_assoc = {}
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
    return answer_assoc

def lower_classes_assoc():
    # Lowercase letters only: 'a'-'z' -> 36-61.
    answer_assoc = {}
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i
    return answer_assoc

def digit_classes_assoc():
    # Digits only: '0'-'9' -> 0-9.
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    return answer_assoc

def tsix_classes_assoc():
    # 36-class, case-insensitive setting: '0'-'9' -> 0-9, and both
    # 'A'-'Z' and 'a'-'z' -> 10-35. The digits are needed because
    # tsix_label_assoc keeps digit labels, so they are scored too.
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
        answer_assoc[chr(i+87)]=i
    return answer_assoc
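In the 36-class setting, case is ignored on both sides. Answers are folded by tsix_classes_assoc, and labels are folded by tsix_label_assoc further down, which subtracts 26 from lowercase classes. The following Python 3 sketch shows the two folds using hypothetical helper names, fold_label and fold_answer. It confirms that a lowercase 'c' and an uppercase 'C' land on the same class.

```python
import string

def fold_label(label):
    """Lowercase classes 36-61 collapse onto uppercase classes 10-35."""
    return label - 26 if 36 <= label <= 61 else label

def fold_answer(ch):
    """Map one answer character to its 36-class index, ignoring case."""
    # Check the length first: '' is contained in every string.
    if len(ch) == 1 and ch in string.digits:
        return int(ch)
    if len(ch) == 1 and ch in string.ascii_letters:
        return ord(ch.upper()) - 55   # 'A' -> 10, ..., 'Z' -> 35
    raise KeyError(ch)                # unknown answers, like a dict miss
```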

def upper_label_assoc(ulabel):
    # Keep uppercase labels (10-35); mark every other position with -1
    # so that get_error skips it. Modifies ulabel in place.
    for i in range(len(ulabel)):
        if ulabel[i] < 10 or ulabel[i] > 35 :
            ulabel[i] = -1
    return ulabel

def lower_label_assoc(ulabel):
    # Keep lowercase labels (36-61); mark every other position with -1.
    for i in range(len(ulabel)):
        if ulabel[i] < 36 or ulabel[i] > 61 :
            ulabel[i] = -1
    return ulabel

def tsix_label_assoc(ulabel):
    # Fold lowercase labels (36-61) onto the matching uppercase class (10-35).
    for i in range(len(ulabel)):
        if ulabel[i] > 35 and ulabel[i] < 62 :
            ulabel[i] = ulabel[i] - 26
    return ulabel

def digit_label_assoc(ulabel):
    # Keep digit labels (0-9); mark every other position with -1.
    for i in range(len(ulabel)):
        if ulabel[i] < 0 or ulabel[i] > 9 :
            ulabel[i] = -1
    return ulabel
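The *_label_assoc functions restrict scoring to one class subset. They overwrite out-of-range labels with -1, and get_error then skips those positions. The following Python 3 sketch does the same masking. It uses a hypothetical restrict helper that returns a new list instead of modifying its argument, and the sample labels are made up.

```python
def restrict(labels, lo, hi):
    """Replace labels outside [lo, hi] by -1, leaving the input unchanged."""
    return [l if lo <= l <= hi else -1 for l in labels]

labels = [3, 12, 40, 7, 35, 61, 0, 10, 36, 9]
upper = restrict(labels, 10, 35)          # same range as upper_label_assoc
scored = sum(1 for l in upper if l != -1)
```

The digit, uppercase and lowercase subsets partition the 62 classes, so their scored counts add up to the number of labels. This matches the sample runs in the docstring: 102 + 63 + 135 = 300.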

def classes_answer(type):
    # Return the answer-character -> class-index map for the requested subset.
    if type == 'all':
        return all_classes_assoc()
    elif type == '36':
        return tsix_classes_assoc()
    elif type == 'lower':
        return lower_classes_assoc()
    elif type == 'upper':
        return upper_classes_assoc()
    elif type == 'digits':
        return digit_classes_assoc()
    else:
        raise ValueError('Inappropriate option for the type of classification: ' + str(type))
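classes_answer dispatches on a string key, and the same logic can be written as a table lookup. The following Python 3 sketch uses hypothetical names (CLASS_MAPS, classes_for) and covers only two subsets to stay short.

```python
def digit_map():
    # '0'-'9' -> 0-9
    return {str(i): i for i in range(10)}

def upper_map():
    # 'A'-'Z' -> 10-35
    return {chr(i + 55): i for i in range(10, 36)}

# Hypothetical table: subset name -> function building its answer map.
CLASS_MAPS = {'digits': digit_map, 'upper': upper_map}

def classes_for(kind):
    """Return the answer map for a subset name, or raise ValueError."""
    if kind not in CLASS_MAPS:
        raise ValueError('Inappropriate option for the type of classification: ' + str(kind))
    return CLASS_MAPS[kind]()
```

With this layout, adding a subset means adding one dictionary entry, and an unknown option still raises a ValueError.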


def test_error(assoc_type=TYPE):
    # Only used to reject an unknown assoc_type before reading the CSV.
    answer_assoc = classes_answer(assoc_type)

    with open(CVSFILE) as f:
        reader = csv.DictReader(f, delimiter=',')
        entries = [ turk for turk in reader ]

    if len(entries) % turks_per_batch != 0 :
        raise Exception('Wrong number of entries or turks_per_batch')

    # Rows are assumed to come in consecutive groups of turks_per_batch,
    # one row per worker for the same batch.
    total_uniq_entries = len(entries) // turks_per_batch

    errors = numpy.zeros((len(entries),))        # wrong answers per row
    num_examples = numpy.zeros((len(entries),))  # answers scored per row
    error_means = numpy.zeros((total_uniq_entries,))
    error_variances = numpy.zeros((total_uniq_entries,))

    for i in range(total_uniq_entries):
        for t in range(turks_per_batch):
            errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
        #error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
        error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()

    # Pooled error rate over every scored answer of every worker.
    percentage_error = 100. * errors.sum() / num_examples.sum()
    print('Testing on : ' + str(assoc_type))
    print('Total entries : ' + str(num_examples.sum()))
    print('Turks per batch : ' + str(turks_per_batch))
    print('Average test error : ' + str(percentage_error) +'%')
    print('Error variance : ' + str(error_variances.mean()*image_per_batch) +'%')
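test_error reports two numbers. The first is a pooled error rate: total wrong answers over total scored answers. The second averages, over batches, the variance of the workers' error counts within each batch. The following worked Python 3 sketch uses made-up counts for two batches of three workers. numpy's var uses ddof=0 by default, as in the script.

```python
import numpy as np

turks_per_batch = 3
# Hypothetical per-row results: wrong answers and answers scored.
errors = np.array([2, 4, 3, 1, 1, 1], dtype=float)
num_examples = np.array([10, 10, 10, 10, 10, 10], dtype=float)

# Pooled error rate, as in test_error.
percentage_error = 100.0 * errors.sum() / num_examples.sum()

# One row per batch, one column per worker; population variance per batch.
per_hit_var = errors.reshape(-1, turks_per_batch).var(axis=1)
```

With these numbers, the pooled rate is 12 / 60 = 20%, and the within-batch variances are 2/3 and 0. Before printing, the script also multiplies the mean variance by image_per_batch.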


def find_dataset(entry):
    # The dataset name is the label-file name up to the first '_'; it must
    # be one of the DATASET_PATH keys.
    file = parse_filename(entry[img_url])
    return file.split('_')[0]

def get_error(entry, type):
    # Score one CSV row (one worker's answers for one batch).
    # Returns (number of wrong answers, number of answers scored).
    answer_assoc = classes_answer(type)
    labels = get_labels(entry,type)
    test_error = 0
    cnt = 0  # positions excluded from the class subset (label == -1)
    for i in range(len(answer_labels)):
        if labels[i] == -1:
            cnt+=1
        else:
            answer = entry[answer_labels[i]]
            try:
                if answer_assoc[answer] != labels[i]:
                    test_error+=1
            except KeyError:
                # An answer outside the current class subset counts as an error.
                test_error+=1
    return test_error,image_per_batch-cnt
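get_error scores one row. Positions masked with -1 are skipped, and an answer that is not a key of the answer map counts as wrong. The following Python 3 sketch applies that rule using a hypothetical score_row helper and made-up answers.

```python
def score_row(answers, labels, answer_map):
    """Return (wrong, scored) for one worker's answers.

    Positions whose label is -1 are skipped; an answer missing from
    answer_map counts as wrong, like the KeyError branch of get_error.
    """
    wrong = scored = 0
    for ans, lab in zip(answers, labels):
        if lab == -1:
            continue
        scored += 1
        if answer_map.get(ans) != lab:   # None never equals an int label
            wrong += 1
    return wrong, scored

digit_map = {str(i): i for i in range(10)}
result = score_row(['3', 'x', '7', '', '0'], [3, -1, 1, 5, 0], digit_map)
```

In this example, '3' and '0' are correct, 'x' is skipped, '7' is wrong, and the empty answer is missing from the map, giving (2, 4). The script returns image_per_batch - cnt instead of counting scored positions. That gives the same number when each row has exactly image_per_batch labels.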


def get_labels(entry,type):
    file = parse_filename(entry[img_url])
    path = DATASET_PATH[find_dataset(entry)]
    # Only the first line of the label file is read. The parsing expects a
    # printed float array such as '[ 3. 12. 45.]': whitespace is removed,
    # the leading '[' and trailing '.]' are sliced off, then it is split on '.'.
    with open(path+file,'r') as f:
        str_labels = re.sub(r"\s+", "", f.readline()).strip()[1:-2].split('.')
    unrestricted_labels = [ int(element) for element in str_labels ]
    if type == 'all':
        return unrestricted_labels
    elif type == '36':
        return tsix_label_assoc(unrestricted_labels)
    elif type == 'lower':
        return lower_label_assoc(unrestricted_labels)
    elif type == 'upper':
        return upper_label_assoc(unrestricted_labels)
    elif type == 'digits':
        return digit_label_assoc(unrestricted_labels)
    else:
        raise ValueError('Inappropriate option for the type of classification: ' + str(type))
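get_labels removes all whitespace from the first line of the label file. It then slices off the leading '[' and the trailing '.]' before splitting on '.'. The following Python 3 sketch shows that parse. It also shows a hypothetical, more tolerant variant that accepts integer formatting as well.

```python
import re

def parse_label_line(line):
    """Parse a line such as '[  3.  12.  45.]' into [3, 12, 45]."""
    compact = re.sub(r"\s+", "", line)   # '[3.12.45.]'
    return [int(x) for x in compact[1:-2].split('.')]

def parse_label_line_float(line):
    """Hypothetical variant: strip the brackets and split on whitespace."""
    return [int(float(x)) for x in line.strip().strip('[]').split()]
```

Consider a line such as '[ 3 12 45]'. The first parser would return one merged number, because removing whitespace joins the integers and the slice then drops the last digit. The second parser avoids this by splitting on whitespace instead.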


def parse_filename(string):
    # Turn the image URL into the matching label file name: keep the last
    # path component and replace its extension with '.txt'.
    filename = string.split('/')[-1]
    return filename.split('.')[0]+'.txt'
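Together, parse_filename and find_dataset turn the image URL of a row into a label file path. The file name is the last path component with its extension replaced by '.txt'. The directory is the DATASET_PATH root named by the part before the first '_'. The following Python 3 sketch uses hypothetical roots and URLs.

```python
# Hypothetical dataset roots; the script's DATASET_PATH has the same shape.
ROOTS = {'nist': '/data/nist/', 'pnist': '/data/pnist/'}

def label_path(image_url):
    """Map an image URL to its label file path."""
    filename = image_url.split('/')[-1]   # e.g. 'pnist_0001.png'
    stem = filename.split('.')[0]         # 'pnist_0001'
    dataset = stem.split('_')[0]          # 'pnist', must be a key of ROOTS
    return ROOTS[dataset] + stem + '.txt'
```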

if __name__ =='__main__':
    import sys
    CVSFILE = sys.argv[1]
    test_error(sys.argv[2])