comparison baseline/mlp/mlp_get_error_from_model.py @ 212:e390b0454515

added classic lr time decay and py code to calculate the error based on a saved model
author xaviermuller
date Wed, 10 Mar 2010 16:17:59 -0500
parents
children 9b6e0af062af
comparison
equal deleted inserted replaced
211:476da2ba6a12 212:e390b0454515
# Docstring markup used by this module (reStructuredText, English).
__docformat__ = 'restructuredtext en'
2
3 import pdb
4 import numpy as np
5 import pylab
6 import time
7 import pylearn
8 from pylearn.io import filetensor as ft
9
# Locations of the test-set filetensor files opened by read_test_data.
# The two file names are relative to data_path and are joined by plain
# string concatenation, which is why data_path keeps its trailing slash.
data_path, test_data, test_labels = (
    '/data/lisa/data/nist/by_class/',
    'all/all_test_data.ft',
    'all/all_test_labels.ft',
)
13
def read_test_data(mlp_model):
    """Load the test set and the parameters of a saved MLP.

    :param mlp_model: anything ``numpy.load`` accepts (path or file object)
        whose result can be indexed with ``'W1'``, ``'b1'``, ``'W2'`` and
        ``'b2'`` (e.g. an ``.npz`` archive).
    :returns: tuple ``(W1, b1, W2, b2, raw_test_data, raw_test_labels)``,
        the layout consumed by ``get_total_test_error``.
    """
    # Filetensor files are binary: open them with 'rb' so no newline
    # translation or text decoding touches the bytes, and use ``with`` so
    # the handles are released even if ``ft.read`` raises.
    with open(data_path + test_data, 'rb') as data_file:
        raw_test_data = ft.read(data_file)
    with open(data_path + test_labels, 'rb') as label_file:
        raw_test_labels = ft.read(label_file)

    # Read the chosen model. For an .npz archive np.load returns a lazy
    # NpzFile that keeps the underlying file open until close() is called;
    # indexing it loads each array fully, so closing afterwards is safe.
    model = np.load(mlp_model)
    try:
        W1 = model['W1']
        W2 = model['W2']
        b1 = model['b1']
        b2 = model['b2']
    finally:
        close = getattr(model, 'close', None)
        if close is not None:
            close()

    return (W1, b1, W2, b2, raw_test_data, raw_test_labels)
33
34
35
36
def _error_rate(error_count, example_count):
    """Return ``error_count`` as a percentage of ``example_count``.

    A group with no examples yields NaN instead of raising
    ZeroDivisionError, so one empty group does not abort the whole report.
    """
    if example_count == 0:
        return float('nan')
    return error_count * 100.0 / example_count


def _predict_classes(W1, b1, W2, b2, raw_data, count, batch_size):
    """Return the predicted class of the first ``count`` rows of ``raw_data``.

    The network is one tanh hidden layer followed by a linear output layer.
    Inputs are divided by 255. Rows are processed ``batch_size`` at a time,
    so neither the float copy of the inputs nor the hidden activations are
    materialised for the whole test set at once.
    """
    b1 = np.ravel(b1)
    b2 = np.ravel(b2)
    predicted = np.empty(count, dtype=np.intp)
    for start in range(0, count, batch_size):
        stop = min(start + batch_size, count)
        batch = (np.asarray(raw_data[start:stop]) / 255.0).reshape(stop - start, -1)
        hidden = np.tanh(np.dot(batch, W1) + b1)
        scores = np.dot(hidden, W2) + b2
        # Softmax is strictly increasing, so the argmax of the raw scores is
        # the argmax of the softmax output. Skipping np.exp avoids overflow
        # to inf, where inf/inf = NaN made argmax return the first NaN.
        predicted[start:stop] = np.argmax(scores, axis=1)
    return predicted


def get_total_test_error(everything, batch_size=1000):
    """Classify the test set with a saved MLP and report grouped error rates.

    :param everything: sequence ``(W1, b1, W2, b2, raw_test_data,
        raw_test_labels)`` as returned by ``read_test_data``.
        ``raw_test_data`` holds one example per row and is divided by 255
        before use (presumably 8-bit pixels -- confirm). ``W1`` maps inputs
        to hidden units and ``W2`` maps hidden units to class scores.
    :param batch_size: number of examples pushed through the network at
        once; bounds peak memory. Must be >= 1.
    :returns: a 15-tuple:

        - example counts (total, digits, letters, min, maj);
        - the matching error counts, in the same order;
        - the matching error rates in percent (NaN for an empty group).
    :raises ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    W1, b1, W2, b2, raw_data, raw_labels = everything[:6]
    labels = np.ravel(raw_labels)
    predicted = _predict_classes(W1, b1, W2, b2, raw_data, labels.size, batch_size)
    wrong = predicted != labels

    # Group masks. This assumes the 62-class NIST by_class layout: labels
    # 0-9 are digits, followed by two blocks of 26 letters (10-35 and 36-61).
    # The previous split at 35 put 25 classes in the first letter group and
    # 27 in the second.
    # NOTE(review): in ASCII order 10-35 would be uppercase, yet this group
    # is reported as "min" -- confirm the min/maj naming.
    is_letter = labels > 9
    groups = (
        np.ones(labels.shape, dtype=bool),  # total
        labels < 10,                        # nb: digits
        is_letter,                          # char: every letter
        is_letter & (labels < 36),          # min: first letter block
        labels > 35,                        # maj: second letter block
    )

    example_counts = tuple(int(np.count_nonzero(mask)) for mask in groups)
    error_counts = tuple(int(np.count_nonzero(wrong & mask)) for mask in groups)
    rates = tuple(_error_rate(errors, examples)
                  for errors, examples in zip(error_counts, example_counts))
    return example_counts + error_counts + rates
112
113
114
115
116
117
118
119
120
121
122
123
124