comparison baseline/mlp/mlp_get_error_from_model.py @ 212:e390b0454515
added classic lr time decay and py code to calculate the error based on a saved model
author   | xaviermuller
date     | Wed, 10 Mar 2010 16:17:59 -0500
parents  |
children | 9b6e0af062af
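
The changeset description mentions classic learning-rate time decay; that
change lives in the training script rather than in this file. As a rough
illustration only, a classic 1/t decay schedule can be sketched as below,
where lr0 (initial rate) and tau (decay constant) are hypothetical names,
not identifiers taken from this repository:

    def decayed_learning_rate(lr0, tau, t):
        # classic time decay: the rate shrinks proportionally to 1/(1 + t/tau)
        return lr0 / (1.0 + float(t) / tau)
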
comparing 211:476da2ba6a12 with 212:e390b0454515
__docformat__ = 'restructuredtext en'

import numpy as np
from pylearn.io import filetensor as ft

data_path = '/data/lisa/data/nist/by_class/'
test_data = 'all/all_test_data.ft'
test_labels = 'all/all_test_labels.ft'

def read_test_data(mlp_model):
    """Load the NIST test set and the parameters of a saved MLP model."""

    # read the test data and labels
    h = open(data_path + test_data, 'rb')
    i = open(data_path + test_labels, 'rb')
    raw_test_data = ft.read(h)
    raw_test_labels = ft.read(i)
    i.close()
    h.close()

    # read the chosen model (an archive holding the weight matrices
    # and bias vectors of both layers)
    a = np.load(mlp_model)
    W1 = a['W1']
    W2 = a['W2']
    b1 = a['b1']
    b2 = a['b2']

    return (W1, b1, W2, b2, raw_test_data, raw_test_labels)


def get_total_test_error(everything):
    """Run the MLP forward pass on every test example and count errors,
    overall and per label group."""

    W1 = everything[0]
    b1 = everything[1]
    W2 = everything[2]
    b2 = everything[3]
    test_data = everything[4] / 255.0  # scale pixel values to [0, 1]
    test_labels = everything[5]

    total_error_count = 0
    total_exemple_count = 0

    # per-group counters: digits ('nb'), all characters ('char'),
    # 'min' (minuscule) and 'maj' (majuscule)
    nb_error_count = 0
    nb_exemple_count = 0

    char_error_count = 0
    char_exemple_count = 0

    min_error_count = 0
    min_exemple_count = 0

    maj_error_count = 0
    maj_exemple_count = 0

    for i in range(test_labels.size):
        total_exemple_count = total_exemple_count + 1

        # hidden layer activation: a0 = W1^T x + b1
        a0 = np.dot(np.transpose(W1), np.transpose(test_data[i])) + b1
        # hidden layer non-linearity
        a0_out = np.tanh(a0)

        # output layer activation: a1 = W2^T h + b2
        a1 = np.dot(np.transpose(W2), a0_out) + b2
        # output non-linearity (softmax); subtracting the max is a
        # numerical-stability guard and does not change the result
        a1_exp = np.exp(a1 - np.max(a1))
        a1_out = a1_exp / np.sum(a1_exp)

        predicted_class = np.argmax(a1_out)
        wanted_class = test_labels[i]

        if predicted_class != wanted_class:
            total_error_count = total_error_count + 1

        # grouped error counts over the label ranges this script assumes:
        # 0-9 are digits ('nb'), 10-34 count toward 'min' and 35 and up
        # toward 'maj'
        if wanted_class > 9 and wanted_class < 35:
            min_exemple_count = min_exemple_count + 1
            if predicted_class != wanted_class:
                min_error_count = min_error_count + 1
        elif wanted_class < 10:
            nb_exemple_count = nb_exemple_count + 1
            if predicted_class != wanted_class:
                nb_error_count = nb_error_count + 1
        elif wanted_class > 34:
            maj_exemple_count = maj_exemple_count + 1
            if predicted_class != wanted_class:
                maj_error_count = maj_error_count + 1

        # any non-digit counts as a character
        if wanted_class > 9:
            char_exemple_count = char_exemple_count + 1
            if predicted_class != wanted_class:
                char_error_count = char_error_count + 1

    # report the raw counts followed by the percentage error rates
    return (total_exemple_count, nb_exemple_count, char_exemple_count,
            min_exemple_count, maj_exemple_count,
            total_error_count, nb_error_count, char_error_count,
            min_error_count, maj_error_count,
            total_error_count * 100.0 / total_exemple_count,
            nb_error_count * 100.0 / nb_exemple_count,
            char_error_count * 100.0 / char_exemple_count,
            min_error_count * 100.0 / min_exemple_count,
            maj_error_count * 100.0 / maj_exemple_count)
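
A minimal driver for the two functions above might look like the sketch
below. The model filename is hypothetical; any archive saved with
np.savez containing 'W1', 'b1', 'W2' and 'b2' arrays would do, and the
eleventh returned value (index 10) is the overall percentage error:

    if __name__ == '__main__':
        everything = read_test_data('mlp_model.npz')  # hypothetical path
        results = get_total_test_error(everything)
        print('total test error: %.2f%%' % results[10])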