comparison baseline/mlp/mlp_get_error_from_model.py @ 237:9b6e0af062af

corrected a bug in jobman interface
author xaviermuller
date Mon, 15 Mar 2010 10:09:50 -0400
parents e390b0454515
children
comparison
equal deleted inserted replaced
236:7be1f086a89e 237:9b6e0af062af
6 import time 6 import time
7 import pylearn 7 import pylearn
8 from pylearn.io import filetensor as ft 8 from pylearn.io import filetensor as ft
9 9
10 data_path = '/data/lisa/data/nist/by_class/' 10 data_path = '/data/lisa/data/nist/by_class/'
11 test_data = 'all/all_test_data.ft' 11 test_data = 'all/all_train_data.ft'
12 test_labels = 'all/all_test_labels.ft' 12 test_labels = 'all/all_train_labels.ft'
13 13
14 def read_test_data(mlp_model): 14 def read_test_data(mlp_model):
15 15
16 16
17 #read the data 17 #read the data
38 38
39 W1=everything[0] 39 W1=everything[0]
40 b1=everything[1] 40 b1=everything[1]
41 W2=everything[2] 41 W2=everything[2]
42 b2=everything[3] 42 b2=everything[3]
43 test_data=everything[4]/255.0 43 test_data=everything[4]
44 test_labels=everything[5] 44 test_labels=everything[5]
45 total_error_count=0 45 total_error_count=0
46 total_exemple_count=0 46 total_exemple_count=0
47 47
48 nb_error_count=0 48 nb_error_count=0
58 maj_exemple_count=0 58 maj_exemple_count=0
59 59
60 for i in range(test_labels.size): 60 for i in range(test_labels.size):
61 total_exemple_count = total_exemple_count +1 61 total_exemple_count = total_exemple_count +1
62 #get activation for layer 1 62 #get activation for layer 1
63 a0=np.dot(np.transpose(W1),np.transpose(test_data[i])) + b1 63 a0=np.dot(np.transpose(W1),np.transpose(test_data[i]/255.0)) + b1
64 #add non linear function to layer 1 activation 64 #add non linear function to layer 1 activation
65 a0_out=np.tanh(a0) 65 a0_out=np.tanh(a0)
66 66
67 #get activation for output layer 67 #get activation for output layer
68 a1= np.dot(np.transpose(W2),a0_out) + b2 68 a1= np.dot(np.transpose(W2),a0_out) + b2
76 76
77 if(predicted_class!=wanted_class): 77 if(predicted_class!=wanted_class):
78 total_error_count = total_error_count +1 78 total_error_count = total_error_count +1
79 79
80 #get grouped based error 80 #get grouped based error
81 #with a priori
82 # if(wanted_class>9 and wanted_class<35):
83 # min_exemple_count=min_exemple_count+1
84 # predicted_class=np.argmax(a1_out[10:35])+10
85 # if(predicted_class!=wanted_class):
86 # min_error_count=min_error_count+1
87 # if(wanted_class<10):
88 # nb_exemple_count=nb_exemple_count+1
89 # predicted_class=np.argmax(a1_out[0:10])
90 # if(predicted_class!=wanted_class):
91 # nb_error_count=nb_error_count+1
92 # if(wanted_class>34):
93 # maj_exemple_count=maj_exemple_count+1
94 # predicted_class=np.argmax(a1_out[35:])+35
95 # if(predicted_class!=wanted_class):
96 # maj_error_count=maj_error_count+1
97 #
98 # if(wanted_class>9):
99 # char_exemple_count=char_exemple_count+1
100 # predicted_class=np.argmax(a1_out[10:])+10
101 # if(predicted_class!=wanted_class):
102 # char_error_count=char_error_count+1
103
104
105
106 #get grouped based error
107 #with no a priori
81 if(wanted_class>9 and wanted_class<35): 108 if(wanted_class>9 and wanted_class<35):
82 min_exemple_count=min_exemple_count+1 109 min_exemple_count=min_exemple_count+1
83 predicted_class=np.argmax(a1_out) 110 predicted_class=np.argmax(a1_out)
84 if(predicted_class!=wanted_class): 111 if(predicted_class!=wanted_class):
85 min_error_count=min_error_count+1 112 min_error_count=min_error_count+1
86 elif(wanted_class<10): 113 if(wanted_class<10):
87 nb_exemple_count=nb_exemple_count+1 114 nb_exemple_count=nb_exemple_count+1
88 predicted_class=np.argmax(a1_out) 115 predicted_class=np.argmax(a1_out)
89 if(predicted_class!=wanted_class): 116 if(predicted_class!=wanted_class):
90 nb_error_count=nb_error_count+1 117 nb_error_count=nb_error_count+1
91 elif(wanted_class>34): 118 if(wanted_class>34):
92 maj_exemple_count=maj_exemple_count+1 119 maj_exemple_count=maj_exemple_count+1
93 predicted_class=np.argmax(a1_out) 120 predicted_class=np.argmax(a1_out)
94 if(predicted_class!=wanted_class): 121 if(predicted_class!=wanted_class):
95 maj_error_count=maj_error_count+1 122 maj_error_count=maj_error_count+1
96 123