diff baseline/mlp/mlp_nist.py @ 323:7a7615f940e8

finished code clean up and testing
author xaviermuller
date Thu, 08 Apr 2010 11:01:55 -0400
parents 743907366476
children 1763c64030d1
line wrap: on
line diff
--- a/baseline/mlp/mlp_nist.py	Tue Apr 06 16:00:52 2010 -0400
+++ b/baseline/mlp/mlp_nist.py	Thu Apr 08 11:01:55 2010 -0400
@@ -182,14 +182,19 @@
     
     #save initial learning rate if classical adaptive lr is used
     initial_lr=learning_rate
+    max_div_count=3
+    
     
     total_validation_error_list = []
     total_train_error_list = []
     learning_rate_list=[]
     best_training_error=float('inf');
+    divergence_flag_list=[]
     
     if data_set==0:
     	dataset=datasets.nist_all()
+    elif data_set==1:
+        dataset=datasets.nist_P07()
     
     
     
@@ -250,23 +255,22 @@
    
    
    #conditions for stopping the adaptation:
-   #1) we have reached  nb_max_exemples (this is rounded up to be a multiple of the train size)
+   #1) we have reached  nb_max_exemples (this is rounded up to be a multiple of the train size so we always do at least 1 epoch)
    #2) validation error is going up twice in a row(probable overfitting)
    
    # This means we no longer stop on slow convergence as low learning rates stopped
-   # too fast. 
+   # too fast but instead we will wait for the valid error going up 3 times in a row
+   # We save the curb of the validation error so we can always go back to check on it 
+   # and we save the absolute best model anyway, so we might as well explore
+   # a bit when diverging
    
-    #approximate number of samples in the training set
+    #approximate number of samples in the nist training set
     #this is just to have a validation frequency
-    #roughly proportionnal to the training set
+    #roughly proportionnal to the original nist training set
     n_minibatches        = 650000/batch_size
     
     
-    patience              =nb_max_exemples/batch_size #in units of minibatch
-    patience_increase     = 2     # wait this much longer when a new best is 
-                                  # found
-    improvement_threshold = 0.995 # a relative improvement of this much is 
-                                  # considered significant
+    patience              =2*nb_max_exemples/batch_size #in units of minibatch
     validation_frequency = n_minibatches/4
    
      
@@ -281,17 +285,17 @@
     minibatch_index=0
     epoch=0
     temp=0
+    divergence_flag=0
     
     
     
     if verbose == 1:
-        print 'looking at most at %i exemples' %nb_max_exemples
+        print 'starting training'
     while(minibatch_index*batch_size<nb_max_exemples):
         
         for x, y in dataset.train(batch_size):
 
-            
-            minibatch_index =  minibatch_index + 1
+            #if we are using the classic learning rate deacay, adjust it before training of current mini-batch
             if adaptive_lr==2:
                     classifier.lr.value = tau*initial_lr/(tau+time_n)
         
@@ -300,17 +304,16 @@
             cost_ij = train_model(x,y)
     
             if (minibatch_index+1) % validation_frequency == 0: 
-                
                 #save the current learning rate
                 learning_rate_list.append(classifier.lr.value)
+                divergence_flag_list.append(divergence_flag)
                 
                 # compute the validation error
                 this_validation_loss = 0.
                 temp=0
                 for xv,yv in dataset.valid(1):
                     # sum up the errors for each minibatch
-                    axxa=test_model(xv,yv)
-                    this_validation_loss += axxa
+                    this_validation_loss += test_model(xv,yv)
                     temp=temp+1
                 # get the average by dividing with the number of minibatches
                 this_validation_loss /= temp
@@ -326,9 +329,14 @@
                     # save best validation score and iteration number
                     best_validation_loss = this_validation_loss
                     best_iter = minibatch_index
-                    # reset patience if we are going down again
-                    # so we continue exploring
-                    patience=nb_max_exemples/batch_size
+                    #reset divergence flag
+                    divergence_flag=0
+                    
+                    #save the best model. Overwrite the current saved best model so
+                    #we only keep the best
+                    numpy.savez('best_model.npy', config=configuration, W1=classifier.W1.value, W2=classifier.W2.value, b1=classifier.b1.value,\
+                    b2=classifier.b2.value, minibatch_index=minibatch_index)
+
                     # test it on the test set
                     test_score = 0.
                     temp =0
@@ -343,21 +351,24 @@
                                     test_score*100.))
                                     
                 # if the validation error is going up, we are overfitting (or oscillating)
-                # stop converging but run at least to next validation
-                # to check overfitting or ocsillation
-                # the saved weights of the model will be a bit off in that case
+                # check if we are allowed to continue and if we will adjust the learning rate
                 elif this_validation_loss >= best_validation_loss:
+                   
+                    
+                    # In non-classic learning rate decay, we modify the weight only when
+                    # validation error is going up
+                    if adaptive_lr==1:
+                        classifier.lr.value=classifier.lr.value*lr_t2_factor
+                           
+                   
+                    #cap the patience so we are allowed to diverge max_div_count times
+                    #if we are going up max_div_count in a row, we will stop immediatelty by modifying the patience
+                    divergence_flag = divergence_flag +1
+                    
+                    
                     #calculate the test error at this point and exit
                     # test it on the test set
-                    # however, if adaptive_lr is true, try reducing the lr to
-                    # get us out of an oscilliation
-                    if adaptive_lr==1:
-                        classifier.lr.value=classifier.lr.value*lr_t2_factor
-    
                     test_score = 0.
-                    #cap the patience so we are allowed one more validation error
-                    #calculation before aborting
-                    patience = minibatch_index+validation_frequency+1
                     temp=0
                     for xt,yt in dataset.test(batch_size):
                         test_score += test_model(xt,yt)
@@ -372,13 +383,22 @@
                                     
                     
     
-    
-            if minibatch_index>patience:
-                print 'we have diverged'
+            # check early stop condition
+            if divergence_flag==max_div_count:
+                minibatch_index=nb_max_exemples
+                print 'we have diverged, early stopping kicks in'
+                break
+            
+            #check if we have seen enough exemples
+            #force one epoch at least
+            if epoch>0 and minibatch_index*batch_size>nb_max_exemples:
                 break
     
     
             time_n= time_n + batch_size
+            minibatch_index =  minibatch_index + 1
+            
+        # we have finished looping through the training set
         epoch = epoch+1
     end_time = time.clock()
     if verbose == 1:
@@ -391,7 +411,7 @@
     #save the model and the weights
     numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value)
     numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\
-    learning_rate_list=learning_rate_list)
+    learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list)
     
     return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60)
 
@@ -410,7 +430,8 @@
 										train_labels = state.train_labels,\
 										test_data = state.test_data,\
 										test_labels = state.test_labels,\
-										lr_t2_factor=state.lr_t2_factor)
+										lr_t2_factor=state.lr_t2_factor,\
+                                                                                data_set=state.data_set)
     state.train_error=train_error
     state.validation_error=validation_error
     state.test_error=test_error