changeset 216:c89004f9cab2

merge
author Dumitru Erhan <dumitru.erhan@gmail.com>
date Wed, 10 Mar 2010 17:08:27 -0500
parents 334d2444000d (current diff) e390b0454515 (diff)
children de3aef84714a
diffstat 3 files changed, 153 insertions(+), 4 deletions(-)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/baseline/mlp/mlp_get_error_from_model.py	Wed Mar 10 17:08:27 2010 -0500
@@ -0,0 +1,124 @@
+__docformat__ = 'restructuredtext en'
+
+import numpy as np
+from pylearn.io import filetensor as ft
+
+data_path = '/data/lisa/data/nist/by_class/'
+test_data = 'all/all_test_data.ft'
+test_labels = 'all/all_test_labels.ft'
+
+def read_test_data(mlp_model):
+    
+    
+    #read the data
+    h = open(data_path+test_data)
+    i = open(data_path+test_labels)
+    raw_test_data = ft.read(h)
+    raw_test_labels = ft.read(i)
+    i.close()
+    h.close()
+    
+    #read the model chosen
+    a=np.load(mlp_model)
+    W1=a['W1']
+    W2=a['W2']
+    b1=a['b1']
+    b2=a['b2']
+    
+    return (W1,b1,W2,b2,raw_test_data,raw_test_labels)
+    
+    
+    
+
+def get_total_test_error(everything):
+    
+    W1=everything[0]
+    b1=everything[1]
+    W2=everything[2]
+    b2=everything[3]
+    test_data=everything[4]/255.0 #rescale pixel values from [0,255] to [0,1]
+    test_labels=everything[5]
+    total_error_count=0
+    total_exemple_count=0
+    
+    nb_error_count=0
+    nb_exemple_count=0
+    
+    char_error_count=0
+    char_exemple_count=0
+    
+    min_error_count=0
+    min_exemple_count=0
+    
+    maj_error_count=0
+    maj_exemple_count=0
+    
+    for i in range(test_labels.size):
+        total_exemple_count = total_exemple_count +1
+        #get activation for layer 1
+        a0=np.dot(np.transpose(W1),np.transpose(test_data[i])) + b1
+        #add non linear function to layer 1 activation
+        a0_out=np.tanh(a0)
+        
+        #get activation for output layer
+        a1= np.dot(np.transpose(W2),a0_out) + b2
+        #add non linear function for output activation (softmax)
+        a1_exp = np.exp(a1)
+        sum_a1=np.sum(a1_exp)
+        a1_out=a1_exp/sum_a1
+        
+        predicted_class=np.argmax(a1_out)
+        wanted_class=test_labels[i]
+        
+        if(predicted_class!=wanted_class):
+            total_error_count = total_error_count +1
+            
+        #get group-based error
+        #(classes 0-9 are digits; 10-35 and 36-61 are the two letter groups)
+        if(wanted_class>9 and wanted_class<36):
+            min_exemple_count=min_exemple_count+1
+            if(predicted_class!=wanted_class):
+                min_error_count=min_error_count+1
+        elif(wanted_class<10):
+            nb_exemple_count=nb_exemple_count+1
+            if(predicted_class!=wanted_class):
+                nb_error_count=nb_error_count+1
+        elif(wanted_class>35):
+            maj_exemple_count=maj_exemple_count+1
+            if(predicted_class!=wanted_class):
+                maj_error_count=maj_error_count+1
+
+        #error over all letters (classes 10-61)
+        if(wanted_class>9):
+            char_exemple_count=char_exemple_count+1
+            if(predicted_class!=wanted_class):
+                char_error_count=char_error_count+1
+    
+    
+    #return exemple counts, error counts, and error rates in percent
+    return ( total_exemple_count,nb_exemple_count,char_exemple_count,min_exemple_count,maj_exemple_count,\
+            total_error_count,nb_error_count,char_error_count,min_error_count,maj_error_count,\
+            total_error_count*100.0/total_exemple_count,\
+            nb_error_count*100.0/nb_exemple_count,\
+            char_error_count*100.0/char_exemple_count,\
+            min_error_count*100.0/min_exemple_count,\
+            maj_error_count*100.0/maj_exemple_count)
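
The new module defines read_test_data and get_total_test_error but no command-line entry point. A minimal driver might look like the sketch below; the .npz path is hypothetical and assumes the model was saved with numpy.savez under the keys W1, W2, b1, b2 that read_test_data expects:

# minimal usage sketch -- 'my_mlp_model.npz' is a hypothetical path,
# not part of this changeset
import mlp_get_error_from_model as gm

everything = gm.read_test_data('my_mlp_model.npz')
results = gm.get_total_test_error(everything)
# results[0:5] are exemple counts, results[5:10] error counts,
# results[10:15] the corresponding error rates in percent
print 'total test error: %.2f %%' % results[10]
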
--- a/baseline/mlp/mlp_nist.py	Wed Mar 10 13:48:16 2010 -0500
+++ b/baseline/mlp/mlp_nist.py	Wed Mar 10 17:08:27 2010 -0500
@@ -31,6 +31,7 @@
 import time 
 import theano.tensor.nnet
 import pylearn
+import theano, pylearn.version
 from pylearn.io import filetensor as ft
 
 data_path = '/data/lisa/data/nist/by_class/'
@@ -174,17 +175,22 @@
                         nb_max_exemples=1000000,\
                         batch_size=20,\
                         nb_hidden = 500,\
-                        nb_targets = 62):
+                        nb_targets = 62,\
+                        tau=1e6):
    
     
     configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr]
     
+    #save the initial learning rate for the 1/t decay schedule (adaptive_lr==2)
+    initial_lr=learning_rate
+    
     total_validation_error_list = []
     total_train_error_list = []
     learning_rate_list=[]
     best_training_error=float('inf');
     
     
    
     f = open(data_path+train_data)
     g= open(data_path+train_labels)
@@ -315,6 +321,8 @@
     n_iter = nb_max_exemples/batch_size  # nb of max times we are allowed to run through all exemples
     n_iter = n_iter/n_minibatches + 1 #round up
     n_iter=max(1,n_iter) # run at least once on short debug call
+    time_n=0 #time counter for the 1/t decay, in units of exemples seen
+    
     
    
     if verbose == True:
@@ -325,6 +333,9 @@
         epoch           = iter / n_minibatches
         minibatch_index =  iter % n_minibatches
         
+
+        #1/t learning rate decay (adaptive_lr==2)
+        if adaptive_lr==2:
+            classifier.lr.value = tau*initial_lr/(tau+time_n)
       
         
         # get the minibatches corresponding to `iter` modulo
@@ -364,6 +375,8 @@
                 print('epoch %i, minibatch %i/%i, validation error %f, training error %f %%' % \
                     (epoch, minibatch_index+1, n_minibatches, \
                         this_validation_loss*100.,this_train_loss*100))
+                print('learning rate = %f' % classifier.lr.value)
+                print('time = %i' % time_n)
                         
                         
             #save the learning rate
@@ -425,6 +438,7 @@
             break
 
 
+        time_n = time_n + batch_size #exemples seen so far, drives the 1/t decay
     end_time = time.clock()
     if verbose == True:
         print(('Optimization complete. Best validation score of %f %% '
@@ -448,7 +462,8 @@
     (train_error,validation_error,test_error,nb_exemples,time)=mlp_full_nist(learning_rate=state.learning_rate,\
                                                                 nb_max_exemples=state.nb_max_exemples,\
                                                                 nb_hidden=state.nb_hidden,\
-                                                                adaptive_lr=state.adaptive_lr)
+                                                                adaptive_lr=state.adaptive_lr,\
+                                                                tau=state.tau)
     state.train_error=train_error
     state.validation_error=validation_error
     state.test_error=test_error
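
The adaptive_lr==2 branch added above implements a 1/t learning-rate decay: with time_n counted in exemples seen, the rate stays close to initial_lr while time_n is small relative to tau and then falls off roughly as tau/time_n, so the default tau=1e6 means the rate is halved after one million exemples. A self-contained sketch of the schedule (the helper function is illustrative, not part of the patch):

# illustrative sketch of the 1/t decay used when adaptive_lr==2
def decayed_lr(initial_lr, tau, time_n):
    # equals initial_lr at time_n=0 and initial_lr/2 at time_n=tau
    return tau * initial_lr / (tau + time_n)

# with the default tau=1e6 the rate is halved after 1e6 exemples
assert abs(decayed_lr(0.1, 1e6, 1e6) - 0.05) < 1e-12
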
--- a/datasets/defs.py	Wed Mar 10 13:48:16 2010 -0500
+++ b/datasets/defs.py	Wed Mar 10 17:08:27 2010 -0500
@@ -1,4 +1,5 @@
-__all__ = ['nist_digits', 'nist_lower', 'nist_upper', 'nist_all', 'ocr']
+__all__ = ['nist_digits', 'nist_lower', 'nist_upper', 'nist_all', 'ocr', 
+           'nist_P07']
 
 from ftfile import FTDataSet
 import theano
@@ -35,4 +36,13 @@
                 test_data = [DATA_PATH+'ocr_test_data.ft'],
                 test_lbl = [DATA_PATH+'ocr_test_labels.ft'],
                 valid_data = [DATA_PATH+'ocr_valid_data.ft'],
-                valid_lbl = [DATA_PATH+'ocr_valid_labels.ft'])
+                valid_lbl = [DATA_PATH+'ocr_valid_labels.ft'],
+                indtype=theano.config.floatX, inscale=255.)
+
+nist_P07 = FTDataSet(train_data = [DATA_PATH+'data/P07_train'+str(i)+'_data.ft' for i in range(100)],
+                     train_lbl = [DATA_PATH+'data/P07_train'+str(i)+'_labels.ft' for i in range(100)],
+                     test_data = [DATA_PATH+'data/P07_test_data.ft'],
+                     test_lbl = [DATA_PATH+'data/P07_test_labels.ft'],
+                     valid_data = [DATA_PATH+'data/P07_valid_data.ft'],
+                     valid_lbl = [DATA_PATH+'data/P07_valid_labels.ft'],
+                     indtype=theano.config.floatX, inscale=255.)
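
The new indtype and inscale arguments suggest that input pixels stored as bytes are cast to Theano's floatX and rescaled by 1/255 when minibatches are produced, matching the manual test_data/255.0 normalization in mlp_get_error_from_model.py. Assuming that reading of the parameters (ftfile.py is not part of this diff), the transformation is roughly:

# rough equivalent of the indtype/inscale handling -- an assumption
# based on the parameter names; ftfile.py is not shown in this changeset
import numpy
import theano

def scale_batch(x_bytes):
    # cast uint8 pixel values to floatX and rescale from [0,255] to [0,1]
    return numpy.asarray(x_bytes, dtype=theano.config.floatX) / 255.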