# HG changeset patch # User xaviermuller # Date 1270738915 14400 # Node ID 7a7615f940e84a6e66e993d1e882ee75fcedf87e # Parent 7439073664766a050410677633b73a019b49c046 finished code clean up and testing diff -r 743907366476 -r 7a7615f940e8 baseline/mlp/mlp_nist.py --- a/baseline/mlp/mlp_nist.py Tue Apr 06 16:00:52 2010 -0400 +++ b/baseline/mlp/mlp_nist.py Thu Apr 08 11:01:55 2010 -0400 @@ -182,14 +182,19 @@ #save initial learning rate if classical adaptive lr is used initial_lr=learning_rate + max_div_count=3 + total_validation_error_list = [] total_train_error_list = [] learning_rate_list=[] best_training_error=float('inf'); + divergence_flag_list=[] if data_set==0: dataset=datasets.nist_all() + elif data_set==1: + dataset=datasets.nist_P07() @@ -250,23 +255,22 @@ #conditions for stopping the adaptation: - #1) we have reached nb_max_exemples (this is rounded up to be a multiple of the train size) + #1) we have reached nb_max_exemples (this is rounded up to be a multiple of the train size so we always do at least 1 epoch) #2) validation error is going up twice in a row(probable overfitting) # This means we no longer stop on slow convergence as low learning rates stopped - # too fast. + # too fast but instead we will wait for the valid error going up 3 times in a row + # We save the curb of the validation error so we can always go back to check on it + # and we save the absolute best model anyway, so we might as well explore + # a bit when diverging - #approximate number of samples in the training set + #approximate number of samples in the nist training set #this is just to have a validation frequency - #roughly proportionnal to the training set + #roughly proportionnal to the original nist training set n_minibatches = 650000/batch_size - patience =nb_max_exemples/batch_size #in units of minibatch - patience_increase = 2 # wait this much longer when a new best is - # found - improvement_threshold = 0.995 # a relative improvement of this much is - # considered significant + patience =2*nb_max_exemples/batch_size #in units of minibatch validation_frequency = n_minibatches/4 @@ -281,17 +285,17 @@ minibatch_index=0 epoch=0 temp=0 + divergence_flag=0 if verbose == 1: - print 'looking at most at %i exemples' %nb_max_exemples + print 'starting training' while(minibatch_index*batch_size= best_validation_loss: + + + # In non-classic learning rate decay, we modify the weight only when + # validation error is going up + if adaptive_lr==1: + classifier.lr.value=classifier.lr.value*lr_t2_factor + + + #cap the patience so we are allowed to diverge max_div_count times + #if we are going up max_div_count in a row, we will stop immediatelty by modifying the patience + divergence_flag = divergence_flag +1 + + #calculate the test error at this point and exit # test it on the test set - # however, if adaptive_lr is true, try reducing the lr to - # get us out of an oscilliation - if adaptive_lr==1: - classifier.lr.value=classifier.lr.value*lr_t2_factor - test_score = 0. - #cap the patience so we are allowed one more validation error - #calculation before aborting - patience = minibatch_index+validation_frequency+1 temp=0 for xt,yt in dataset.test(batch_size): test_score += test_model(xt,yt) @@ -372,13 +383,22 @@ - - if minibatch_index>patience: - print 'we have diverged' + # check early stop condition + if divergence_flag==max_div_count: + minibatch_index=nb_max_exemples + print 'we have diverged, early stopping kicks in' + break + + #check if we have seen enough exemples + #force one epoch at least + if epoch>0 and minibatch_index*batch_size>nb_max_exemples: break time_n= time_n + batch_size + minibatch_index = minibatch_index + 1 + + # we have finished looping through the training set epoch = epoch+1 end_time = time.clock() if verbose == 1: @@ -391,7 +411,7 @@ #save the model and the weights numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value) numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\ - learning_rate_list=learning_rate_list) + learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list) return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60) @@ -410,7 +430,8 @@ train_labels = state.train_labels,\ test_data = state.test_data,\ test_labels = state.test_labels,\ - lr_t2_factor=state.lr_t2_factor) + lr_t2_factor=state.lr_t2_factor,\ + data_set=state.data_set) state.train_error=train_error state.validation_error=validation_error state.test_error=test_error