Mercurial > ift6266
changeset 338:fca22114bb23
added async save, restart from old model and independant error calculation based on Arnaud's iterator
author | xaviermuller |
---|---|
date | Sat, 17 Apr 2010 12:42:48 -0400 |
parents | 8d116d4a7593 |
children | ffbf0e41bcee 9685e9d94cc4 |
files | baseline/mlp/mlp_nist.py |
diffstat | 1 files changed, 94 insertions(+), 5 deletions(-) [+] |
line wrap: on
line diff
--- a/baseline/mlp/mlp_nist.py Fri Apr 16 16:05:55 2010 -0400 +++ b/baseline/mlp/mlp_nist.py Sat Apr 17 12:42:48 2010 -0400 @@ -163,6 +163,75 @@ else: raise NotImplementedError() +def mlp_get_nist_error(model_name='/u/mullerx/ift6266h10_sandbox_db/xvm_final_lr1_p073/8/best_model.npy.npz', + data_set=0): + + + + # allocate symbolic variables for the data + x = T.fmatrix() # the data is presented as rasterized images + y = T.lvector() # the labels are presented as 1D vector of + # [long int] labels + + # load the data set and create an mlp based on the dimensions of the model + model=numpy.load(model_name) + W1=model['W1'] + W2=model['W2'] + b1=model['b1'] + b2=model['b2'] + nb_hidden=b1.shape[0] + input_dim=W1.shape[0] + nb_targets=b2.shape[0] + learning_rate=0.1 + + + if data_set==0: + dataset=datasets.nist_all() + elif data_set==1: + dataset=datasets.nist_P07() + + + classifier = MLP( input=x,\ + n_in=input_dim,\ + n_hidden=nb_hidden,\ + n_out=nb_targets, + learning_rate=learning_rate) + + + #overwrite weights with weigths from model + classifier.W1.value=W1 + classifier.W2.value=W2 + classifier.b1.value=b1 + classifier.b2.value=b2 + + + cost = classifier.negative_log_likelihood(y) \ + + 0.0 * classifier.L1 \ + + 0.0 * classifier.L2_sqr + + # compiling a theano function that computes the mistakes that are made by + # the model on a minibatch + test_model = theano.function([x,y], classifier.errors(y)) + + + + #get the test error + #use a batch size of 1 so we can get the sub-class error + #without messing with matrices (will be upgraded later) + test_score=0 + temp=0 + for xt,yt in dataset.test(20): + test_score += test_model(xt,yt) + temp = temp+1 + test_score /= temp + + + return test_score*100 + + + + + def mlp_full_nist( verbose = 1,\ adaptive_lr = 0,\ @@ -174,15 +243,19 @@ batch_size=20,\ nb_hidden = 30,\ nb_targets = 62, - tau=1e6,\ - lr_t2_factor=0.5): + tau=1e6,\ + lr_t2_factor=0.5,\ + init_model=0,\ + channel=0): + if channel!=0: + channel.save() configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr] #save initial learning rate if classical adaptive lr is used initial_lr=learning_rate - max_div_count=3 + max_div_count=1000 total_validation_error_list = [] @@ -215,6 +288,14 @@ learning_rate=learning_rate) + # check if we want to initialise the weights with a previously calculated model + # dimensions must be consistent between old model and current configuration!!!!!! (nb_hidden and nb_targets) + if init_model!=0: + old_model=numpy.load(init_model) + classifier.W1.value=old_model['W1'] + classifier.W2.value=old_model['W2'] + classifier.b1.value=old_model['b1'] + classifier.b2.value=old_model['b2'] # the cost we minimize during training is the negative log likelihood of @@ -303,10 +384,14 @@ #train model cost_ij = train_model(x,y) - if (minibatch_index+1) % validation_frequency == 0: + if (minibatch_index) % validation_frequency == 0: #save the current learning rate learning_rate_list.append(classifier.lr.value) divergence_flag_list.append(divergence_flag) + + #save temp results to check during training + numpy.savez('temp_results.npy',config=configuration,total_validation_error_list=total_validation_error_list,\ + learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list) # compute the validation error this_validation_loss = 0. @@ -393,6 +478,9 @@ #force one epoch at least if epoch>0 and minibatch_index*batch_size>nb_max_exemples: break + + + time_n= time_n + batch_size @@ -427,7 +515,8 @@ tau=state.tau,\ verbose = state.verbose,\ lr_t2_factor=state.lr_t2_factor, - data_set=state.data_set) + data_set=state.data_set, + channel=channel) state.train_error=train_error state.validation_error=validation_error state.test_error=test_error