changeset 414:3dba84c0fbc1

saving test score from best validation score in db now
author xaviermuller
date Thu, 29 Apr 2010 17:04:12 -0400
parents f2dd75248483
children 1e9788ce1680
files baseline/mlp/mlp_nist.py
diffstat 1 files changed, 79 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/baseline/mlp/mlp_nist.py	Thu Apr 29 16:51:03 2010 -0400
+++ b/baseline/mlp/mlp_nist.py	Thu Apr 29 17:04:12 2010 -0400
@@ -195,6 +195,21 @@
     maj_error_count=0.0
     maj_exemple_count=0.0
     
+    vtotal_error_count=0.0
+    vtotal_exemple_count=0.0
+    
+    vnb_error_count=0.0
+    vnb_exemple_count=0.0
+    
+    vchar_error_count=0.0
+    vchar_exemple_count=0.0
+    
+    vmin_error_count=0.0
+    vmin_exemple_count=0.0
+    
+    vmaj_error_count=0.0
+    vmaj_exemple_count=0.0
+    
     
 
     if data_set==0:
@@ -256,12 +271,73 @@
                 min_error_count = min_error_count +1
             
             
+            
+    vtest_score=0
+    vtemp=0
+    for xt,yt in dataset.valid(1):
+        
+        vtotal_exemple_count = vtotal_exemple_count +1
+        #get activation for layer 1
+        a0=numpy.dot(numpy.transpose(W1),numpy.transpose(xt[0])) + b1
+        #add non linear function to layer 1 activation
+        a0_out=numpy.tanh(a0)
+        
+        #get activation for output layer
+        a1= numpy.dot(numpy.transpose(W2),a0_out) + b2
+        #add non linear function for output activation (softmax)
+        a1_exp = numpy.exp(a1)
+        sum_a1=numpy.sum(a1_exp)
+        a1_out=a1_exp/sum_a1
+        
+        predicted_class=numpy.argmax(a1_out)
+        wanted_class=yt[0]
+        if(predicted_class!=wanted_class):
+            vtotal_error_count = vtotal_error_count +1
+            
+        #treat digit error
+        if(wanted_class<10):
+            vnb_exemple_count=vnb_exemple_count + 1
+            predicted_class=numpy.argmax(a1_out[0:10])
+            if(predicted_class!=wanted_class):
+                vnb_error_count = vnb_error_count +1
+                
+        if(wanted_class>9):
+            vchar_exemple_count=vchar_exemple_count + 1
+            predicted_class=numpy.argmax(a1_out[10:62])+10
+            if((predicted_class!=wanted_class) and ((predicted_class+26)!=wanted_class) and ((predicted_class-26)!=wanted_class)):
+               vchar_error_count = vchar_error_count +1
+               
+        #minuscule
+        if(wanted_class>9 and wanted_class<36):
+            vmaj_exemple_count=vmaj_exemple_count + 1
+            predicted_class=numpy.argmax(a1_out[10:35])+10
+            if(predicted_class!=wanted_class):
+                vmaj_error_count = vmaj_error_count +1
+        #majuscule
+        if(wanted_class>35):
+            vmin_exemple_count=vmin_exemple_count + 1
+            predicted_class=numpy.argmax(a1_out[36:62])+36
+            if(predicted_class!=wanted_class):
+                vmin_error_count = vmin_error_count +1
+            
 
     print (('total error = %f') % ((total_error_count/total_exemple_count)*100.0))
     print (('number error = %f') % ((nb_error_count/nb_exemple_count)*100.0))
     print (('char error = %f') % ((char_error_count/char_exemple_count)*100.0))
     print (('min error = %f') % ((min_error_count/min_exemple_count)*100.0))
     print (('maj error = %f') % ((maj_error_count/maj_exemple_count)*100.0))
+    
+    print (('valid total error = %f') % ((vtotal_error_count/vtotal_exemple_count)*100.0))
+    print (('valid number error = %f') % ((vnb_error_count/vnb_exemple_count)*100.0))
+    print (('valid char error = %f') % ((vchar_error_count/vchar_exemple_count)*100.0))
+    print (('valid min error = %f') % ((vmin_error_count/vmin_exemple_count)*100.0))
+    print (('valid maj error = %f') % ((vmaj_error_count/vmaj_exemple_count)*100.0))
+    
+    print ((' num total = %d,%d') % (total_exemple_count,total_error_count))
+    print ((' num nb = %d,%d') % (nb_exemple_count,nb_error_count))
+    print ((' num min = %d,%d') % (min_exemple_count,min_error_count))
+    print ((' num maj = %d,%d') % (maj_exemple_count,maj_error_count))
+    print ((' num char = %d,%d') % (char_exemple_count,char_error_count))
     return (total_error_count/total_exemple_count)*100.0
     
 
@@ -292,6 +368,7 @@
     #save initial learning rate if classical adaptive lr is used
     initial_lr=learning_rate
     max_div_count=1000
+    optimal_test_error=0
     
     
     total_validation_error_list = []
@@ -482,6 +559,7 @@
 				(epoch, minibatch_index+1,
 				test_score*100.))
                     sys.stdout.flush()
+                    optimal_test_error=test_score
                                     
                 # if the validation error is going up, we are overfitting (or oscillating)
                 # check if we are allowed to continue and if we will adjust the learning rate
@@ -551,7 +629,7 @@
     numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\
     learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list)
     
-    return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60)
+    return (best_training_error*100.0,best_validation_loss * 100.,optimal_test_error*100.,best_iter*batch_size,(end_time-start_time)/60)
 
 
 if __name__ == '__main__':