annotate deep/amt/amt.py @ 643:24d9819a810f

reviews aistats finales (final AISTATS reviews)
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Thu, 24 Mar 2011 17:04:38 -0400
parents 7bdd412754ea
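Changesets that last modified the lines below (rev node author, parent: commit message):
399 99905d9bc9dd humel: Initial commit for calculating the test error of the AMT classifier
401 86d5e583e278 humel, parent 399: Fixed class number bug
402 83413ac10913 humel, parent 401: Added more stats printing. Now you dont need to parameters which dataset you are testing, it will detect it automatically
403 a11692910312 humel, parent 402: Undoing some unwanted changes
412 6478eef4f8aa humel, parent 403: Added support for calculating the test error over different set of classes (lower,upper,digits,all,36)
430 5777b5041ac9 Dumitru Erhan <dumitru.erhan@gmail.com>, parent 412: fixed error computation for 36 classes
449 7bdd412754ea humel, parent 430: Added support to calculate the human consensus error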

# Script usage: python amt.py filename.csv type
# where type is one of: all, 36, upper, lower, digits
"""
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv all
Testing on : all
Total entries : 300.0
Turks per batch : 3
Average test error : 45.3333333333%
Error variance : 7.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv 36
Testing on : 36
Total entries : 300.0
Turks per batch : 3
Average test error : 51.6666666667%
Error variance : 3.33333333333%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv upper
Testing on : upper
Total entries : 63.0
Turks per batch : 3
Average test error : 53.9682539683%
Error variance : 1.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv lower
Testing on : lower
Total entries : 135.0
Turks per batch : 3
Average test error : 37.037037037%
Error variance : 3.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv digits
Testing on : digits
Total entries : 102.0
Turks per batch : 3
Average test error : 50.9803921569%
Error variance : 1.33333333333%
"""

import csv, numpy, re, decimal
from ift6266 import datasets
from pylearn.io import filetensor as ft

# Class-frequency tables for the nist, p07 and pnist training sets, read with pylearn's filetensor.
fnist = open('nist_train_class_freq.ft','rb')
fp07 = open('p07_train_class_freq.ft','rb')
fpnist = open('pnist_train_class_freq.ft','rb')

nist_freq_table = ft.read(fnist)
p07_freq_table = ft.read(fp07)
pnist_freq_table = ft.read(fpnist)

fnist.close(); fp07.close(); fpnist.close()

DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/',
                 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/',
                 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }

freq_tables = { 'nist' : nist_freq_table,
                'p07' : p07_freq_table,
                'pnist': pnist_freq_table }


CVSFILE = None
#PATH = None
# CSV columns holding one worker's answers for the ten images of a batch.
answer_labels = [ 'Answer.c'+str(i+1) for i in range(10) ]
img_url = 'Input.image_url'
turks_per_batch = 3    # workers (turks) who labelled each batch
image_per_batch = 10   # images shown in each batch
TYPE = None
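test_error reads the AMT results with csv.DictReader and slices the rows in groups of turks_per_batch, so it assumes the rows of each batch appear consecutively, with the answers in the Answer.c1 ... Answer.c10 columns. The sketch below reproduces that grouping on a small in-memory file (Python 3). The toy rows, the example.org URLs and the helper names make_toy_csv and group_into_batches are made up for illustration; real AMT exports contain many more columns.

```python
import csv
import io

# Column names mirroring the constants in amt.py.
answer_labels = ['Answer.c' + str(i + 1) for i in range(10)]
img_url = 'Input.image_url'
turks_per_batch = 3


def make_toy_csv(urls):
    """Build a made-up results file: turks_per_batch consecutive rows per batch."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow([img_url] + answer_labels)
    for url in urls:
        for _ in range(turks_per_batch):
            writer.writerow([url] + ['A'] * len(answer_labels))
    return buf.getvalue()


def group_into_batches(csv_text, per_batch=turks_per_batch):
    """Read the rows as dicts and split them into consecutive per-batch groups."""
    entries = list(csv.DictReader(io.StringIO(csv_text)))
    if len(entries) % per_batch != 0:
        raise ValueError('Wrong number of entries or turks_per_batch')
    return [entries[i:i + per_batch] for i in range(0, len(entries), per_batch)]


batches = group_into_batches(make_toy_csv(['http://example.org/hit0.png',
                                           'http://example.org/hit1.png']))
print(len(batches))                 # 2
print(batches[1][0][img_url])       # http://example.org/hit1.png
print(batches[0][2]['Answer.c10'])  # A
```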

def all_classes_assoc():
    # 62 classes: '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i
    return answer_assoc

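The chr offsets in all_classes_assoc work because, in ASCII, chr(10 + 55) is 'A' and chr(36 + 61) is 'a'. The sketch below (Python 3, not part of amt.py; the name all_classes_assoc_from_strings is hypothetical) builds the same 62-entry table from the string module and checks it against a copy of the loop version.

```python
import string


def all_classes_assoc():
    """Loop version, copied from amt.py."""
    answer_assoc = {}
    for i in range(0, 10):
        answer_assoc[str(i)] = i
    for i in range(10, 36):
        answer_assoc[chr(i + 55)] = i  # chr(65) == 'A'
    for i in range(36, 62):
        answer_assoc[chr(i + 61)] = i  # chr(97) == 'a'
    return answer_assoc


def all_classes_assoc_from_strings():
    """Same table: class id = position in '0-9', then 'A-Z', then 'a-z'."""
    symbols = string.digits + string.ascii_uppercase + string.ascii_lowercase
    return {ch: idx for idx, ch in enumerate(symbols)}


assoc = all_classes_assoc_from_strings()
assert assoc == all_classes_assoc()
print(len(assoc), assoc['9'], assoc['A'], assoc['Z'], assoc['a'], assoc['z'])
# 62 9 10 35 36 61
```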
def upper_classes_assoc():
    # 'A'-'Z' -> 10-35
    answer_assoc = {}
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
    return answer_assoc

def lower_classes_assoc():
    # 'a'-'z' -> 36-61
    answer_assoc = {}
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i
    return answer_assoc

def digit_classes_assoc():
    # '0'-'9' -> 0-9
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    return answer_assoc

def tsix_classes_assoc():
    # 36 case-insensitive classes: '0'-'9' -> 0-9, 'A'-'Z' and 'a'-'z' -> 10-35
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
        answer_assoc[chr(i+87)]=i
    return answer_assoc

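The upper, lower and digits tables are restrictions of the 62-class table, while the 36-class table sends each lowercase letter to the id of its uppercase counterpart. A quick check of those relationships (Python 3), written against condensed copies of the amt.py builders; the restrict helper is hypothetical.

```python
def all_classes_assoc():
    # Condensed copy of amt.py: '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61.
    assoc = {str(i): i for i in range(0, 10)}
    assoc.update({chr(i + 55): i for i in range(10, 36)})
    assoc.update({chr(i + 61): i for i in range(36, 62)})
    return assoc


def tsix_classes_assoc():
    # Condensed copy of amt.py: digits keep 0-9, 'A'/'a' .. 'Z'/'z' share 10-35.
    assoc = {str(i): i for i in range(0, 10)}
    for i in range(10, 36):
        assoc[chr(i + 55)] = i  # uppercase
        assoc[chr(i + 87)] = i  # lowercase
    return assoc


full = all_classes_assoc()


def restrict(lo, hi):
    """Entries of the 62-class table whose class id lies in [lo, hi]."""
    return {ch: idx for ch, idx in full.items() if lo <= idx <= hi}


upper, lower, digits = restrict(10, 35), restrict(36, 61), restrict(0, 9)
tsix = tsix_classes_assoc()

# The subset builders in amt.py produce exactly these restrictions.
assert upper == {chr(i + 55): i for i in range(10, 36)}   # upper_classes_assoc()
assert lower == {chr(i + 61): i for i in range(36, 62)}   # lower_classes_assoc()
assert digits == {str(i): i for i in range(0, 10)}        # digit_classes_assoc()

# In the 36-class table each lowercase answer shares its uppercase letter's id.
assert all(tsix[ch] == tsix[ch.upper()] == full[ch.upper()] for ch in lower)

print(len(upper), len(lower), len(digits), len(tsix), len(set(tsix.values())))
# 26 26 10 62 36
```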
def upper_label_assoc(ulabel):
    # Keep uppercase labels (10-35); mark all others -1 so they are ignored. Modifies ulabel in place.
    for i in range(len(ulabel)):
        if ulabel[i] < 10 or ulabel[i] > 35:
            ulabel[i] = -1
    return ulabel

def lower_label_assoc(ulabel):
    # Keep lowercase labels (36-61); mark all others -1. Modifies ulabel in place.
    for i in range(len(ulabel)):
        if ulabel[i] < 36 or ulabel[i] > 61:
            ulabel[i] = -1
    return ulabel

def tsix_label_assoc(ulabel):
    # Fold lowercase labels (36-61) onto their uppercase ids (10-35). Modifies ulabel in place.
    for i in range(len(ulabel)):
        if ulabel[i] > 35 and ulabel[i] < 62:
            ulabel[i] = ulabel[i] - 26
    return ulabel

def digit_label_assoc(ulabel):
    # Keep digit labels (0-9); mark all others -1. Modifies ulabel in place.
    for i in range(len(ulabel)):
        if ulabel[i] < 0 or ulabel[i] > 9:
            ulabel[i] = -1
    return ulabel

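The *_label_assoc helpers loop over the labels and overwrite them in place, so whatever object get_labels returns (get_labels is not shown in this excerpt) is modified. Assuming the labels can be handled as a NumPy integer array, a vectorized variant that returns a copy could look like this; mask_outside and fold_lowercase are hypothetical names.

```python
import numpy as np


def mask_outside(labels, lo, hi):
    """Copy of labels with every id outside [lo, hi] replaced by -1 (ignored)."""
    labels = np.asarray(labels)
    return np.where((labels >= lo) & (labels <= hi), labels, -1)


def fold_lowercase(labels):
    """Copy of labels with lowercase ids 36-61 mapped onto uppercase ids 10-35."""
    labels = np.asarray(labels)
    return np.where((labels >= 36) & (labels <= 61), labels - 26, labels)


labels = np.array([3, 12, 40, 61, 9])
print(mask_outside(labels, 10, 35).tolist())  # [-1, 12, -1, -1, -1]  (like upper_label_assoc)
print(mask_outside(labels, 36, 61).tolist())  # [-1, -1, 40, 61, -1]  (like lower_label_assoc)
print(mask_outside(labels, 0, 9).tolist())    # [3, -1, -1, -1, 9]    (like digit_label_assoc)
print(fold_lowercase(labels).tolist())        # [3, 12, 14, 35, 9]    (like tsix_label_assoc)
print(labels.tolist())                        # [3, 12, 40, 61, 9]    (input unchanged)
```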
def classes_answer(type):
    # Return the answer -> class-id table for the requested evaluation type.
    if type == 'all':
        return all_classes_assoc()
    elif type == '36':
        return tsix_classes_assoc()
    elif type == 'lower':
        return lower_classes_assoc()
    elif type == 'upper':
        return upper_classes_assoc()
    elif type == 'digits':
        return digit_classes_assoc()
    else:
        raise ValueError('Inappropriate option for the type of classification: ' + str(type))


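classes_answer is a chain of string comparisons. An equivalent table-driven form keeps the valid options in one place, so the error message can list them. This is a sketch (Python 3, uses raise ... from None) with hypothetical names CLASS_TABLES and _table, not the script's code.

```python
import string


def _table(symbols, start=0):
    """Map each symbol to consecutive class ids beginning at start."""
    return {ch: i for i, ch in enumerate(symbols, start)}


def _tsix():
    # Digits 0-9, then both letter cases share the ids 10-35.
    assoc = _table(string.digits + string.ascii_uppercase)
    assoc.update(_table(string.ascii_lowercase, 10))
    return assoc


CLASS_TABLES = {
    'all': lambda: _table(string.digits + string.ascii_uppercase + string.ascii_lowercase),
    '36': _tsix,
    'lower': lambda: _table(string.ascii_lowercase, 36),
    'upper': lambda: _table(string.ascii_uppercase, 10),
    'digits': lambda: _table(string.digits),
}


def classes_answer(kind):
    """Return the answer -> class-id table for kind, or raise ValueError."""
    try:
        builder = CLASS_TABLES[kind]
    except KeyError:
        raise ValueError('Inappropriate option for the type of classification: %r '
                         '(expected one of %s)' % (kind, ', '.join(sorted(CLASS_TABLES)))) from None
    return builder()


print(classes_answer('36')['q'], classes_answer('lower')['q'], classes_answer('upper')['Q'])
# 26 52 26
```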
def test_error(assoc_type=TYPE,consensus=True):
    # The default is the value TYPE had when this def ran (None),
    # so callers should pass the evaluation type explicitly.
    answer_assoc = classes_answer(assoc_type)

    turks = []
    with open(CVSFILE) as csvfile:
        reader = csv.DictReader(csvfile, delimiter=',')
        entries = [ turk for turk in reader ]

    errors = numpy.zeros((len(entries),))
    if len(entries) % turks_per_batch != 0:
        raise Exception('Wrong number of entries or turks_per_batch')

    # Number of distinct batches; each was answered by turks_per_batch workers.
    total_uniq_entries = len(entries) // turks_per_batch

    error_variances = numpy.zeros((total_uniq_entries,))

    if consensus:
        # One consensus score per batch.
        errors = numpy.zeros((total_uniq_entries,))
        num_examples = numpy.zeros((total_uniq_entries,))
        for i in range(total_uniq_entries):
            errors[i],num_examples[i] = get_turk_consensus_error(entries[i*turks_per_batch:(i+1)*turks_per_batch],assoc_type)
            # NOTE: errors holds one entry per batch here, so this slice does not pick out
            # the workers of batch i; once i*turks_per_batch >= total_uniq_entries the
            # slice is empty and var() returns nan.
            error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
    else:
        # One score per worker row.
        errors = numpy.zeros((len(entries),))
        num_examples = numpy.zeros((len(entries),))
        for i in range(total_uniq_entries):
            for t in range(turks_per_batch):
                errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type)
            error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()

    percentage_error = 100. * errors.sum() / num_examples.sum()
    print 'Testing on : ' + str(assoc_type)
    print 'Total entries : ' + str(num_examples.sum())
    print 'Turks per batch : ' + str(turks_per_batch)
    print 'Average test error : ' + str(percentage_error) +'%'
    print 'Error variance : ' + str(error_variances.mean()*image_per_batch) +'%'


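In the non-consensus branch of test_error, errors and num_examples hold one value per worker row in batch order, and error_variances holds the variance of the turks_per_batch error counts within each batch. The sketch below reproduces those two reductions on made-up numbers (Python 3) and shows that the per-batch loop is equivalent to a single reshape(...).var(axis=1). It does not apply to the consensus branch, where errors holds one value per batch.

```python
import numpy as np

turks_per_batch = 3

# Made-up per-row results for 2 batches x 3 workers, in CSV row order.
errors = np.array([4, 5, 6, 2, 2, 2], dtype=float)              # mistakes per worker row
num_examples = np.array([10, 10, 10, 10, 10, 10], dtype=float)  # images scored per row

# Pooled error rate: total mistakes over total scored images, as in test_error.
percentage_error = 100. * errors.sum() / num_examples.sum()

# Per-batch variance across workers: the loop form used in amt.py ...
n_batches = len(errors) // turks_per_batch
loop_var = np.array([errors[i * turks_per_batch:(i + 1) * turks_per_batch].var()
                     for i in range(n_batches)])
# ... and the same numbers from one reshape (rows are grouped batch by batch).
vec_var = errors.reshape(n_batches, turks_per_batch).var(axis=1)

assert np.allclose(loop_var, vec_var)
print(percentage_error)               # 35.0
print(np.round(vec_var, 3).tolist())  # [0.667, 0.0]
```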
def find_dataset(entry):
    # Assumes the image file name starts with the dataset name ('nist', 'p07' or 'pnist') followed by '_'.
    filename = parse_filename(entry[img_url])
    return filename.split('_')[0]

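parse_filename is defined elsewhere in amt.py and is not part of this excerpt. Assuming it returns the last path component of the image URL, and that image files are named with the dataset as a prefix followed by '_' (which is what split('_')[0] and the freq_tables keys imply), the lookup can be sketched as follows in Python 3. parse_filename_sketch, find_dataset_sketch and the example URLs are hypothetical.

```python
import posixpath
from urllib.parse import urlparse

KNOWN_DATASETS = ('nist', 'p07', 'pnist')  # the keys of freq_tables in amt.py


def parse_filename_sketch(url):
    """Hypothetical stand-in for parse_filename: last component of the URL path."""
    return posixpath.basename(urlparse(url).path)


def find_dataset_sketch(url):
    """Dataset name = file-name prefix before the first '_'; fail loudly if unknown."""
    prefix = parse_filename_sketch(url).split('_')[0]
    if prefix not in KNOWN_DATASETS:
        raise ValueError('Unrecognized dataset prefix %r in %s' % (prefix, url))
    return prefix


print(find_dataset_sketch('http://example.org/amt/pnist_batch_0042.png'))  # pnist
print(find_dataset_sketch('http://example.org/amt/p07_7.png?x=1'))         # p07
```

Failing on an unknown prefix turns a later KeyError in freq_tables into an explicit message naming the offending URL.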
def get_error(entry, type):
    # Count the wrong answers of one turker on one batch of images.
    # answer_labels and image_per_batch are module-level globals.
    answer_assoc = classes_answer(type)
    labels = get_labels(entry, type)
    test_error = 0
    cnt = 0  # images whose label is outside the chosen class subset (-1)
    for i in range(len(answer_labels)):
        if labels[i] == -1:
            cnt += 1
        else:
            answer = entry[answer_labels[i]]
            try:
                if answer_assoc[answer] != labels[i]:
                    test_error += 1
            except KeyError:
                # An answer that does not map to any class counts as an error.
                test_error += 1
    # Return (number of errors, number of images actually scored).
    return test_error, image_per_batch - cnt

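The per-batch scoring done by get_error() can be sketched without the module globals. This is a minimal illustration under stated assumptions, not the script's own code: batch_error and its parameters are hypothetical names, the answers are passed in directly instead of being read through answer_labels, and the number of scored images is computed from len(labels) rather than from the global image_per_batch.

```python
def batch_error(answers, labels, answer_to_class):
    """Hypothetical helper: count wrong answers in one batch.

    answers: raw turker answers, one per image.
    labels: integer class labels, -1 meaning "outside the evaluated subset".
    answer_to_class: dict mapping a raw answer to an integer class.
    Returns (number_of_errors, number_of_scored_images).
    """
    errors = 0
    skipped = 0
    for answer, label in zip(answers, labels):
        if label == -1:
            skipped += 1  # image not part of the evaluated class subset
        elif answer_to_class.get(answer) != label:
            errors += 1  # wrong answer, or one that maps to no class (None)
    return errors, len(labels) - skipped
```

For example, batch_error(['a', 'x', '3', 'b'], [10, 11, -1, 11], {'a': 10, 'b': 11, '3': 3}) returns (1, 3): 'x' maps to no class and the third image is skipped.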
def get_turk_consensus_error(entries, type):
    # Error of the consensus answer of the turkers who labelled the same batch.
    # The tie-breaking below assumes three turkers per batch (answers[0..2]).
    answer_assoc = classes_answer(type)
    labels = get_labels(entries[0], type)
    test_error = 0
    cnt = 0
    freq_t = freq_tables[find_dataset(entries[0])]
    for i in range(len(answer_labels)):
        if labels[i] == -1:
            cnt += 1
        else:
            answers = [entry[answer_labels[i]] for entry in entries]
            if answers[0] != answers[1] and answers[1] != answers[2] and answers[0] != answers[2]:
                # All three answers differ: keep the one whose class is the
                # most frequent in the training set.
                m = max([freq_t[answer_assoc[answer]] for answer in answers])
                for answer in answers:
                    if freq_t[answer_assoc[answer]] == m:
                        a = answer
            else:
                # At least two answers agree: keep the majority answer.
                for answer in answers:
                    if answers.count(answer) > 1:
                        a = answer
            try:
                # Score the consensus answer a, not the loop variable answer.
                if answer_assoc[a] != labels[i]:
                    test_error += 1
            except KeyError:
                test_error += 1
    return test_error, image_per_batch - cnt
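The consensus rule in get_turk_consensus_error() (majority vote, with a three-way disagreement resolved in favour of the class most frequent in the training set) can be sketched as a standalone function. consensus_answer and its parameters are hypothetical. Two deliberate differences from the loop above: max() returns the first answer with the highest frequency whereas the loop keeps the last one, and an unrecognised answer gets frequency -1 instead of raising KeyError.

```python
def consensus_answer(answers, answer_to_class, class_freq):
    """Hypothetical helper: consensus of a few turker answers.

    If at least two answers agree, the majority answer wins. Otherwise keep
    the answer whose class has the largest training count class_freq[c].
    """
    for ans in answers:
        if answers.count(ans) > 1:
            return ans

    def freq(ans):
        # Unknown answers never beat an answer that maps to a real class.
        c = answer_to_class.get(ans)
        return class_freq[c] if c is not None else -1

    return max(answers, key=freq)
```

With answer_to_class = {'a': 0, 'b': 1, 'c': 2} and class_freq = [5, 9, 2], the answers ['a', 'b', 'a'] give 'a' by majority and ['a', 'b', 'c'] give 'b' by frequency.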

def frequency_table():
    # Count the class frequencies (62 classes) of each training set and
    # write one table per dataset to a .ft file.
    filenames = ['nist_train_class_freq.ft', 'p07_train_class_freq.ft', 'pnist_train_class_freq.ft']
    iterators = [datasets.nist_all(), datasets.nist_P07(), datasets.PNIST07()]
    for dataset, filename in zip(iterators, filenames):
        freq_table = numpy.zeros(62)
        for x, y in dataset.train(1):
            freq_table[int(y)] += 1
        f = open(filename, 'w')
        ft.write(f, freq_table)
        f.close()

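The counting in frequency_table() is a histogram over the 62 classes. The sketch below uses a plain list so it runs without numpy or the datasets module; class_frequencies is a hypothetical name, and the labels are taken from any iterable instead of from dataset.train(1).

```python
def class_frequencies(labels, n_classes=62):
    """Hypothetical helper: count how many times each class label occurs."""
    freq = [0] * n_classes
    for y in labels:
        freq[int(y)] += 1  # labels may arrive as floats such as 3.0
    return freq
```

With numpy available, numpy.bincount(numpy.asarray(labels, dtype=int), minlength=62) yields the same counts as an integer array.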
def get_labels(entry, type):
    # Read the true labels of the batch image referenced by entry and map
    # them onto the requested class subset.
    file = parse_filename(entry[img_url])
    path = DATASET_PATH[find_dataset(entry)]
    f = open(path + file, 'r')
    # Remove all whitespace, drop the leading '[' and the trailing '.]',
    # then split on the '.' that ends each float.
    str_labels = re.sub(r"\s+", "", f.readline()).strip()[1:-2].split('.')
    f.close()
    unrestricted_labels = [int(element) for element in str_labels]
    if type == 'all':
        return unrestricted_labels
    elif type == '36':
        return tsix_label_assoc(unrestricted_labels)
    elif type == 'lower':
        return lower_label_assoc(unrestricted_labels)
    elif type == 'upper':
        return upper_label_assoc(unrestricted_labels)
    elif type == 'digits':
        return digit_label_assoc(unrestricted_labels)
    else:
        raise ValueError('Inappropriate option for the type of classification: ' + str(type))

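The slicing in get_labels() implies that the first line of each label file is a whitespace-padded list of floats such as '[ 12.  3. 45.]'; that layout is an assumption read off the code, not a documented format. The parsing step in isolation, under that assumption (parse_label_line is a hypothetical helper):

```python
import re


def parse_label_line(line):
    """Hypothetical helper: turn '[ 12.  3. 45.]' into [12, 3, 45].

    Same steps as get_labels(): remove all whitespace, drop the leading '['
    and the trailing '.]', then split on the '.' that ends each float.
    """
    compact = re.sub(r"\s+", "", line)
    return [int(element) for element in compact[1:-2].split('.')]
```

Because it slices fixed positions, it only works for that exact layout; a label of -1 ('[ -1.  7.]') still parses correctly.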

def parse_filename(string):
    # Map an image URL to the name of its label file: '.../name.ext' -> 'name.txt'
    filename = string.split('/')[-1]
    return filename.split('.')[0] + '.txt'

if __name__ == '__main__':
    import sys
    CVSFILE = sys.argv[1]
    test_error(sys.argv[2], int(sys.argv[3]))
    #frequency_table()
