Mercurial > ift6266
comparison deep/stacked_dae/v_sylvain/nist_apriori_error.py @ 418:fb028b37ce92
Sert a calculer l'erreur sur les differentes classes de NIST.
author | SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca> |
---|---|
date | Fri, 30 Apr 2010 14:47:25 -0400 |
parents | |
children | 5ca2936f2062 |
comparison
equal
deleted
inserted
replaced
417:0282882aa91f | 418:fb028b37ce92 |
---|---|
1 __docformat__ = 'restructedtext en' | |
2 | |
3 import pdb | |
4 import numpy | |
5 from numpy import array | |
6 import time | |
7 import datetime | |
8 import pylearn | |
9 import copy | |
10 import sys | |
11 import os | |
12 import os.path | |
13 from pylearn.io import filetensor as ft | |
14 from jobman import DD | |
15 from ift6266 import datasets | |
16 import cPickle | |
17 from copy import copy | |
18 import math | |
19 | |
20 from config import * | |
21 | |
# Root directory of the NIST "by_class" data (presumably on the LISA lab
# filesystem, judging by the path -- confirm before running elsewhere).
data_path = '/data/lisa/data/nist/by_class/'
# NOTE(review): the name `test_data` is rebound by the `def test_data(...)`
# that follows in this file, so this string is unreachable after import.
test_data = 'all/all_train_data.ft'
test_labels = 'all/all_train_labels.ft'
# Hyper-parameter set; the evaluation function reads
# state['num_hidden_layers'] from it. DEFAULT_HP_NIST is presumably provided
# by the star-import from `config` -- verify.
state = DD(DEFAULT_HP_NIST)
26 | |
# Index of the first letter class: labels 0..9 form the first group.
_FIRST_LETTER_CLASS = 10
# Index where the second 26-letter group starts (10 digits + 26 letters).
# NOTE(review): assumes the 10 / 26 / 26 class layout of NIST by_class, i.e.
# labels 10..35 are one alphabet and 36..61 the other -- confirm against the
# label files.
_SECOND_ALPHABET_CLASS = 36


def _hidden_layer_output(inputs, W, b, type):
    """Return one hidden layer's activation, rescaled to lie in (0, 1).

    `type` == 1 selects tanh (mapped from (-1, 1) to (0, 1)); any other
    value selects the logistic sigmoid.
    """
    pre_activation = numpy.dot(inputs, W) + b
    if type == 1:
        return (numpy.tanh(pre_activation) + 1.0) / 2.0
    return 1.0 / (1.0 + numpy.exp(-pre_activation))


def test_data(sda_model, dataset, part=2, type=0):
    """Print the "a priori" classification error of a saved model.

    For every example the prediction is the arg-max of the network output
    restricted to the class group the true label belongs to (labels < 10,
    10..35, or >= 36), i.e. the group is assumed to be known in advance.

    Parameters:
        sda_model -- path of the pickled parameter file; it is read as the
                     flat sequence [W1, b1, ..., Wn, bn, Wo, bo]
        dataset   -- object exposing train(1) / valid(1) / test(1), each
                     yielding (x, y) pairs (presumably minibatches of size
                     one -- the output is indexed as a 2-D array with a
                     single row)
        part      -- 0=train, 1=valid, 2=test
        type      -- non-linearity type, 0=sigmoid, 1=tanh

    Raises:
        ValueError -- if state['num_hidden_layers'] is neither 3 nor 4, or
                      if `part` is not 0, 1 or 2.

    Returns nothing; the counts and the error percentage are printed.
    """
    num_hidden_layers = state['num_hidden_layers']
    if num_hidden_layers not in (3, 4):
        # Fail here instead of printing and crashing later on an unbound
        # output layer.
        raise ValueError('Number of layers not implemented yet, please do it')

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    f = open(sda_model, 'rb')
    try:
        parameters_pre = cPickle.load(f)
    finally:
        # Close the file even when unpickling fails.
        f.close()

    # array() copies, so the loaded parameter objects are never aliased.
    hidden_layers = [
        (array(parameters_pre[2 * i]), array(parameters_pre[2 * i + 1]))
        for i in range(num_hidden_layers)
    ]
    Wo = array(parameters_pre[2 * num_hidden_layers])
    bo = array(parameters_pre[2 * num_hidden_layers + 1])

    if part == 0:
        examples = dataset.train(1)
    elif part == 1:
        examples = dataset.valid(1)
    elif part == 2:
        examples = dataset.test(1)
    else:
        raise ValueError(
            'part must be 0 (train), 1 (valid) or 2 (test), got %r' % (part,))

    total_error_count = 0
    total_exemple_count = 0
    for x, y in examples:
        total_exemple_count += 1

        # Forward pass through the hidden layers.
        hidden = x
        for W, b in hidden_layers:
            hidden = _hidden_layer_output(hidden, W, b, type)
        out_act = numpy.dot(hidden, Wo) + bo

        # Softmax output. Subtracting the maximum leaves the result
        # unchanged mathematically but keeps exp() from overflowing to inf
        # (which would turn the probabilities into NaN).
        a1_exp = numpy.exp(out_act - numpy.max(out_act))
        out = a1_exp / numpy.sum(a1_exp)

        # Restrict the prediction to the group of the true label.
        if y < _FIRST_LETTER_CLASS:
            first, last = 0, _FIRST_LETTER_CLASS
        elif y < _SECOND_ALPHABET_CLASS:
            first, last = _FIRST_LETTER_CLASS, _SECOND_ALPHABET_CLASS
        else:
            first, last = _SECOND_ALPHABET_CLASS, out.shape[1]
        predicted_class = numpy.argmax(out[0, first:last]) + first
        if predicted_class != y:
            total_error_count += 1

    if total_exemple_count:
        error_percentage = total_error_count * 100.0 / total_exemple_count
    else:
        # Empty dataset: report "nan" instead of dividing by zero.
        error_percentage = float('nan')

    print('\t total exemples count: ' + str(total_exemple_count))
    print('\t total error count: ' + str(total_error_count))
    print('\t percentage of error: ' + str(error_percentage) + ' %')
131 | |
132 | |
def sigmoid(value):
    """Return the logistic sigmoid 1 / (1 + exp(-value)) of a scalar.

    The two branches are algebraically identical; the one used never calls
    exp() on a positive argument, so large-magnitude inputs saturate to 0.0
    or 1.0 instead of raising OverflowError.
    """
    if value >= 0:
        return 1.0 / (1.0 + math.exp(-value))
    # For negative inputs exp(-value) could overflow; use exp(value) instead.
    exp_value = math.exp(value)
    return exp_value / (1.0 + exp_value)
142 | |
if __name__ == '__main__':
    # Usage: nist_apriori_error.py [sigmoid|tanh]
    # For each known fine-tuned parameter file found under PATH, print the
    # a-priori error on the NIST digit, lower-case and upper-case sets.

    args = sys.argv[1:]

    # With no argument, fall back to sigmoid (the default of test_data)
    # instead of leaving the non-linearity undefined.
    if not args or args[0] == 'sigmoid':
        nonlinearity = 0
    elif args[0] == 'tanh':
        nonlinearity = 1
    else:
        sys.exit("Unknown non-linearity %r: expected 'sigmoid' or 'tanh'"
                 % (args[0],))

    part = 2  # 0=train, 1=valid, 2=test

    PATH = ''  # Can be changed if the model is not in the current directory

    # (parameter file, banner printed before its results)
    models = [
        ('params_finetune_NIST.txt', '\n finetune = NIST '),
        ('params_finetune_P07.txt', '\n finetune = P07 '),
        ('params_finetune_NIST_then_P07.txt', '\n finetune = NIST then P07'),
        ('params_finetune_P07_then_NIST.txt', '\n finetune = P07 then NIST'),
        ('params_finetune_PNIST07.txt', '\n finetune = PNIST07'),
        ('params_finetune_PNIST07_then_NIST.txt',
         '\n finetune = PNIST07 then NIST'),
    ]

    # (title, dataset factory); a fresh dataset is built for every run.
    subsets = [
        ('NIST DIGITS', datasets.nist_digits),
        ('NIST LOWER CASE', datasets.nist_lower),
        ('NIST UPPER CASE', datasets.nist_upper),
    ]

    for filename, banner in models:
        model_path = PATH + filename
        if not os.path.exists(model_path):
            # Parameter files that are not present are silently skipped.
            continue
        start_time = time.clock()
        print(banner)
        for title, make_dataset in subsets:
            print(title)
            test_data(model_path, make_dataset(), part=part,
                      type=nonlinearity)
        end_time = time.clock()
        print('It took %f minutes' % ((end_time - start_time) / 60.))
229 | |
230 | |
231 | |
232 | |
233 | |
234 | |
235 | |
236 | |
237 | |
238 | |
239 |