Mercurial > ift6266
comparison baseline/mlp/mlp_get_error_from_model.py @ 237:9b6e0af062af
corrected a bug in jobman interface
field | value |
---|---|
author | xaviermuller |
date | Mon, 15 Mar 2010 10:09:50 -0400 |
parents | e390b0454515 |
children | |
comparison legend: equal | deleted | inserted | replaced
236:7be1f086a89e | 237:9b6e0af062af |
---|---|
6 import time | 6 import time |
7 import pylearn | 7 import pylearn |
8 from pylearn.io import filetensor as ft | 8 from pylearn.io import filetensor as ft |
9 | 9 |
10 data_path = '/data/lisa/data/nist/by_class/' | 10 data_path = '/data/lisa/data/nist/by_class/' |
11 test_data = 'all/all_test_data.ft' | 11 test_data = 'all/all_train_data.ft' |
12 test_labels = 'all/all_test_labels.ft' | 12 test_labels = 'all/all_train_labels.ft' |
13 | 13 |
14 def read_test_data(mlp_model): | 14 def read_test_data(mlp_model): |
15 | 15 |
16 | 16 |
17 #read the data | 17 #read the data |
38 | 38 |
39 W1=everything[0] | 39 W1=everything[0] |
40 b1=everything[1] | 40 b1=everything[1] |
41 W2=everything[2] | 41 W2=everything[2] |
42 b2=everything[3] | 42 b2=everything[3] |
43 test_data=everything[4]/255.0 | 43 test_data=everything[4] |
44 test_labels=everything[5] | 44 test_labels=everything[5] |
45 total_error_count=0 | 45 total_error_count=0 |
46 total_exemple_count=0 | 46 total_exemple_count=0 |
47 | 47 |
48 nb_error_count=0 | 48 nb_error_count=0 |
58 maj_exemple_count=0 | 58 maj_exemple_count=0 |
59 | 59 |
60 for i in range(test_labels.size): | 60 for i in range(test_labels.size): |
61 total_exemple_count = total_exemple_count +1 | 61 total_exemple_count = total_exemple_count +1 |
62 #get activation for layer 1 | 62 #get activation for layer 1 |
63 a0=np.dot(np.transpose(W1),np.transpose(test_data[i])) + b1 | 63 a0=np.dot(np.transpose(W1),np.transpose(test_data[i]/255.0)) + b1 |
64 #add non linear function to layer 1 activation | 64 #add non linear function to layer 1 activation |
65 a0_out=np.tanh(a0) | 65 a0_out=np.tanh(a0) |
66 | 66 |
67 #get activation for output layer | 67 #get activation for output layer |
68 a1= np.dot(np.transpose(W2),a0_out) + b2 | 68 a1= np.dot(np.transpose(W2),a0_out) + b2 |
76 | 76 |
77 if(predicted_class!=wanted_class): | 77 if(predicted_class!=wanted_class): |
78 total_error_count = total_error_count +1 | 78 total_error_count = total_error_count +1 |
79 | 79 |
80 #get grouped based error | 80 #get grouped based error |
81 #with a priori | |
82 # if(wanted_class>9 and wanted_class<35): | |
83 # min_exemple_count=min_exemple_count+1 | |
84 # predicted_class=np.argmax(a1_out[10:35])+10 | |
85 # if(predicted_class!=wanted_class): | |
86 # min_error_count=min_error_count+1 | |
87 # if(wanted_class<10): | |
88 # nb_exemple_count=nb_exemple_count+1 | |
89 # predicted_class=np.argmax(a1_out[0:10]) | |
90 # if(predicted_class!=wanted_class): | |
91 # nb_error_count=nb_error_count+1 | |
92 # if(wanted_class>34): | |
93 # maj_exemple_count=maj_exemple_count+1 | |
94 # predicted_class=np.argmax(a1_out[35:])+35 | |
95 # if(predicted_class!=wanted_class): | |
96 # maj_error_count=maj_error_count+1 | |
97 # | |
98 # if(wanted_class>9): | |
99 # char_exemple_count=char_exemple_count+1 | |
100 # predicted_class=np.argmax(a1_out[10:])+10 | |
101 # if(predicted_class!=wanted_class): | |
102 # char_error_count=char_error_count+1 | |
103 | |
104 | |
105 | |
106 #get grouped based error | |
107 #with no a priori | |
81 if(wanted_class>9 and wanted_class<35): | 108 if(wanted_class>9 and wanted_class<35): |
82 min_exemple_count=min_exemple_count+1 | 109 min_exemple_count=min_exemple_count+1 |
83 predicted_class=np.argmax(a1_out) | 110 predicted_class=np.argmax(a1_out) |
84 if(predicted_class!=wanted_class): | 111 if(predicted_class!=wanted_class): |
85 min_error_count=min_error_count+1 | 112 min_error_count=min_error_count+1 |
86 elif(wanted_class<10): | 113 if(wanted_class<10): |
87 nb_exemple_count=nb_exemple_count+1 | 114 nb_exemple_count=nb_exemple_count+1 |
88 predicted_class=np.argmax(a1_out) | 115 predicted_class=np.argmax(a1_out) |
89 if(predicted_class!=wanted_class): | 116 if(predicted_class!=wanted_class): |
90 nb_error_count=nb_error_count+1 | 117 nb_error_count=nb_error_count+1 |
91 elif(wanted_class>34): | 118 if(wanted_class>34): |
92 maj_exemple_count=maj_exemple_count+1 | 119 maj_exemple_count=maj_exemple_count+1 |
93 predicted_class=np.argmax(a1_out) | 120 predicted_class=np.argmax(a1_out) |
94 if(predicted_class!=wanted_class): | 121 if(predicted_class!=wanted_class): |
95 maj_error_count=maj_error_count+1 | 122 maj_error_count=maj_error_count+1 |
96 | 123 |