Mercurial > ift6266
changeset 323:7a7615f940e8
finished code clean up and testing
author | xaviermuller |
---|---|
date | Thu, 08 Apr 2010 11:01:55 -0400 |
parents | 743907366476 |
children | 1763c64030d1 |
files | baseline/mlp/mlp_nist.py |
diffstat | 1 files changed, 55 insertions(+), 34 deletions(-) [+] |
line wrap: on
line diff
--- a/baseline/mlp/mlp_nist.py Tue Apr 06 16:00:52 2010 -0400 +++ b/baseline/mlp/mlp_nist.py Thu Apr 08 11:01:55 2010 -0400 @@ -182,14 +182,19 @@ #save initial learning rate if classical adaptive lr is used initial_lr=learning_rate + max_div_count=3 + total_validation_error_list = [] total_train_error_list = [] learning_rate_list=[] best_training_error=float('inf'); + divergence_flag_list=[] if data_set==0: dataset=datasets.nist_all() + elif data_set==1: + dataset=datasets.nist_P07() @@ -250,23 +255,22 @@ #conditions for stopping the adaptation: - #1) we have reached nb_max_exemples (this is rounded up to be a multiple of the train size) + #1) we have reached nb_max_exemples (this is rounded up to be a multiple of the train size so we always do at least 1 epoch) #2) validation error is going up twice in a row(probable overfitting) # This means we no longer stop on slow convergence as low learning rates stopped - # too fast. + # too fast but instead we will wait for the valid error going up 3 times in a row + # We save the curb of the validation error so we can always go back to check on it + # and we save the absolute best model anyway, so we might as well explore + # a bit when diverging - #approximate number of samples in the training set + #approximate number of samples in the nist training set #this is just to have a validation frequency - #roughly proportionnal to the training set + #roughly proportionnal to the original nist training set n_minibatches = 650000/batch_size - patience =nb_max_exemples/batch_size #in units of minibatch - patience_increase = 2 # wait this much longer when a new best is - # found - improvement_threshold = 0.995 # a relative improvement of this much is - # considered significant + patience =2*nb_max_exemples/batch_size #in units of minibatch validation_frequency = n_minibatches/4 @@ -281,17 +285,17 @@ minibatch_index=0 epoch=0 temp=0 + divergence_flag=0 if verbose == 1: - print 'looking at most at %i exemples' %nb_max_exemples + print 'starting training' while(minibatch_index*batch_size<nb_max_exemples): for x, y in dataset.train(batch_size): - - minibatch_index = minibatch_index + 1 + #if we are using the classic learning rate deacay, adjust it before training of current mini-batch if adaptive_lr==2: classifier.lr.value = tau*initial_lr/(tau+time_n) @@ -300,17 +304,16 @@ cost_ij = train_model(x,y) if (minibatch_index+1) % validation_frequency == 0: - #save the current learning rate learning_rate_list.append(classifier.lr.value) + divergence_flag_list.append(divergence_flag) # compute the validation error this_validation_loss = 0. temp=0 for xv,yv in dataset.valid(1): # sum up the errors for each minibatch - axxa=test_model(xv,yv) - this_validation_loss += axxa + this_validation_loss += test_model(xv,yv) temp=temp+1 # get the average by dividing with the number of minibatches this_validation_loss /= temp @@ -326,9 +329,14 @@ # save best validation score and iteration number best_validation_loss = this_validation_loss best_iter = minibatch_index - # reset patience if we are going down again - # so we continue exploring - patience=nb_max_exemples/batch_size + #reset divergence flag + divergence_flag=0 + + #save the best model. Overwrite the current saved best model so + #we only keep the best + numpy.savez('best_model.npy', config=configuration, W1=classifier.W1.value, W2=classifier.W2.value, b1=classifier.b1.value,\ + b2=classifier.b2.value, minibatch_index=minibatch_index) + # test it on the test set test_score = 0. temp =0 @@ -343,21 +351,24 @@ test_score*100.)) # if the validation error is going up, we are overfitting (or oscillating) - # stop converging but run at least to next validation - # to check overfitting or ocsillation - # the saved weights of the model will be a bit off in that case + # check if we are allowed to continue and if we will adjust the learning rate elif this_validation_loss >= best_validation_loss: + + + # In non-classic learning rate decay, we modify the weight only when + # validation error is going up + if adaptive_lr==1: + classifier.lr.value=classifier.lr.value*lr_t2_factor + + + #cap the patience so we are allowed to diverge max_div_count times + #if we are going up max_div_count in a row, we will stop immediatelty by modifying the patience + divergence_flag = divergence_flag +1 + + #calculate the test error at this point and exit # test it on the test set - # however, if adaptive_lr is true, try reducing the lr to - # get us out of an oscilliation - if adaptive_lr==1: - classifier.lr.value=classifier.lr.value*lr_t2_factor - test_score = 0. - #cap the patience so we are allowed one more validation error - #calculation before aborting - patience = minibatch_index+validation_frequency+1 temp=0 for xt,yt in dataset.test(batch_size): test_score += test_model(xt,yt) @@ -372,13 +383,22 @@ - - if minibatch_index>patience: - print 'we have diverged' + # check early stop condition + if divergence_flag==max_div_count: + minibatch_index=nb_max_exemples + print 'we have diverged, early stopping kicks in' + break + + #check if we have seen enough exemples + #force one epoch at least + if epoch>0 and minibatch_index*batch_size>nb_max_exemples: break time_n= time_n + batch_size + minibatch_index = minibatch_index + 1 + + # we have finished looping through the training set epoch = epoch+1 end_time = time.clock() if verbose == 1: @@ -391,7 +411,7 @@ #save the model and the weights numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value) numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\ - learning_rate_list=learning_rate_list) + learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list) return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60) @@ -410,7 +430,8 @@ train_labels = state.train_labels,\ test_data = state.test_data,\ test_labels = state.test_labels,\ - lr_t2_factor=state.lr_t2_factor) + lr_t2_factor=state.lr_t2_factor,\ + data_set=state.data_set) state.train_error=train_error state.validation_error=validation_error state.test_error=test_error