comparison baseline/mlp/mlp_nist.py @ 323:7a7615f940e8

finished code clean up and testing
author xaviermuller
date Thu, 08 Apr 2010 11:01:55 -0400
parents 743907366476
children 1763c64030d1
322:743907366476 323:7a7615f940e8
180 180
181 configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr] 181 configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr]
182 182
183 #save initial learning rate if classical adaptive lr is used 183 #save initial learning rate if classical adaptive lr is used
184 initial_lr=learning_rate 184 initial_lr=learning_rate
185 max_div_count=3
186
185 187
186 total_validation_error_list = [] 188 total_validation_error_list = []
187 total_train_error_list = [] 189 total_train_error_list = []
188 learning_rate_list=[] 190 learning_rate_list=[]
189 best_training_error=float('inf'); 191 best_training_error=float('inf');
192 divergence_flag_list=[]
190 193
191 if data_set==0: 194 if data_set==0:
192 dataset=datasets.nist_all() 195 dataset=datasets.nist_all()
196 elif data_set==1:
197 dataset=datasets.nist_P07()
193 198
194 199
195 200
196 201
197 ishape = (32,32) # this is the size of NIST images 202 ishape = (32,32) # this is the size of NIST images
248 253
249 254
250 255
251 256
252 #conditions for stopping the adaptation: 257 #conditions for stopping the adaptation:
253 #1) we have reached nb_max_exemples (this is rounded up to be a multiple of the train size) 258 #1) we have reached nb_max_exemples (this is rounded up to be a multiple of the train size so we always do at least 1 epoch)
254 #2) validation error is going up twice in a row (probable overfitting) 259 #2) validation error is going up max_div_count times in a row (probable overfitting)
255 260
256 # This means we no longer stop on slow convergence as low learning rates stopped 261 # This means we no longer stop on slow convergence as low learning rates stopped
257 # too fast. 262 # too fast but instead we will wait for the validation error to go up 3 times in a row
258 263 # We save the curve of the validation error so we can always go back to check on it
259 #approximate number of samples in the training set 264 # and we save the absolute best model anyway, so we might as well explore
265 # a bit when diverging
266
267 #approximate number of samples in the nist training set
260 #this is just to have a validation frequency 268 #this is just to have a validation frequency
261 #roughly proportional to the training set 269 #roughly proportional to the original nist training set
262 n_minibatches = 650000/batch_size 270 n_minibatches = 650000/batch_size
263 271
264 272
265 patience =nb_max_exemples/batch_size #in units of minibatch 273 patience =2*nb_max_exemples/batch_size #in units of minibatch
266 patience_increase = 2 # wait this much longer when a new best is
267 # found
268 improvement_threshold = 0.995 # a relative improvement of this much is
269 # considered significant
270 validation_frequency = n_minibatches/4 274 validation_frequency = n_minibatches/4
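A quick sketch (not part of the changeset) of how the validation frequency above works out; the batch_size of 20 is an assumed value for illustration, the real one comes from the caller:

batch_size = 20                              # assumed for the example
n_minibatches = 650000 / batch_size          # 32500 minibatches per nominal NIST epoch (integer division in Python 2)
validation_frequency = n_minibatches / 4     # 8125 -> validate roughly four times per nominal epoch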
271 275
272 276
273 277
274 278
279 start_time = time.clock() 283 start_time = time.clock()
280 time_n=0 #in unit of exemples 284 time_n=0 #in unit of exemples
281 minibatch_index=0 285 minibatch_index=0
282 epoch=0 286 epoch=0
283 temp=0 287 temp=0
288 divergence_flag=0
284 289
285 290
286 291
287 if verbose == 1: 292 if verbose == 1:
288 print 'looking at most at %i exemples' %nb_max_exemples 293 print 'starting training'
289 while(minibatch_index*batch_size<nb_max_exemples): 294 while(minibatch_index*batch_size<nb_max_exemples):
290 295
291 for x, y in dataset.train(batch_size): 296 for x, y in dataset.train(batch_size):
292 297
293 298 #if we are using the classic learning rate decay, adjust it before training on the current mini-batch
294 minibatch_index = minibatch_index + 1
295 if adaptive_lr==2: 299 if adaptive_lr==2:
296 classifier.lr.value = tau*initial_lr/(tau+time_n) 300 classifier.lr.value = tau*initial_lr/(tau+time_n)
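For reference, a minimal sketch of the classic 1/t schedule used above when adaptive_lr==2; tau and initial_lr are assumed values here, the real ones are passed in by the caller:

initial_lr = 0.1
tau = 100000.0
for time_n in (0, 100000, 400000):
    lr = tau * initial_lr / (tau + time_n)
    # time_n = 0      -> lr = 0.1   (the initial rate)
    # time_n = 100000 -> lr = 0.05  (halved after tau examples)
    # time_n = 400000 -> lr = 0.02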
297 301
298 302
299 #train model 303 #train model
300 cost_ij = train_model(x,y) 304 cost_ij = train_model(x,y)
301 305
302 if (minibatch_index+1) % validation_frequency == 0: 306 if (minibatch_index+1) % validation_frequency == 0:
303
304 #save the current learning rate 307 #save the current learning rate
305 learning_rate_list.append(classifier.lr.value) 308 learning_rate_list.append(classifier.lr.value)
309 divergence_flag_list.append(divergence_flag)
306 310
307 # compute the validation error 311 # compute the validation error
308 this_validation_loss = 0. 312 this_validation_loss = 0.
309 temp=0 313 temp=0
310 for xv,yv in dataset.valid(1): 314 for xv,yv in dataset.valid(1):
311 # sum up the errors for each minibatch 315 # sum up the errors for each minibatch
312 axxa=test_model(xv,yv) 316 this_validation_loss += test_model(xv,yv)
313 this_validation_loss += axxa
314 temp=temp+1 317 temp=temp+1
315 # get the average by dividing by the number of minibatches 318 # get the average by dividing by the number of minibatches
316 this_validation_loss /= temp 319 this_validation_loss /= temp
317 #save the validation loss 320 #save the validation loss
318 total_validation_error_list.append(this_validation_loss) 321 total_validation_error_list.append(this_validation_loss)
324 # if we got the best validation score until now 327 # if we got the best validation score until now
325 if this_validation_loss < best_validation_loss: 328 if this_validation_loss < best_validation_loss:
326 # save best validation score and iteration number 329 # save best validation score and iteration number
327 best_validation_loss = this_validation_loss 330 best_validation_loss = this_validation_loss
328 best_iter = minibatch_index 331 best_iter = minibatch_index
329 # reset patience if we are going down again 332 #reset divergence flag
330 # so we continue exploring 333 divergence_flag=0
331 patience=nb_max_exemples/batch_size 334
335 #save the best model. Overwrite the current saved best model so
336 #we only keep the best
337 numpy.savez('best_model.npy', config=configuration, W1=classifier.W1.value, W2=classifier.W2.value, b1=classifier.b1.value,\
338 b2=classifier.b2.value, minibatch_index=minibatch_index)
339
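A minimal sketch of reading the snapshot written above back into memory; note that numpy.savez appends a .npz extension, so the file saved as 'best_model.npy' ends up as 'best_model.npy.npz' on disk:

import numpy
snapshot = numpy.load('best_model.npy.npz')
W1, W2 = snapshot['W1'], snapshot['W2']          # hidden and output weight matrices
b1, b2 = snapshot['b1'], snapshot['b2']          # corresponding biases
best_minibatch = snapshot['minibatch_index']     # where the best validation score occurred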
332 # test it on the test set 340 # test it on the test set
333 test_score = 0. 341 test_score = 0.
334 temp =0 342 temp =0
335 for xt,yt in dataset.test(batch_size): 343 for xt,yt in dataset.test(batch_size):
336 test_score += test_model(xt,yt) 344 test_score += test_model(xt,yt)
341 'model %f %%') % 349 'model %f %%') %
342 (epoch, minibatch_index+1, 350 (epoch, minibatch_index+1,
343 test_score*100.)) 351 test_score*100.))
344 352
345 # if the validation error is going up, we are overfitting (or oscillating) 353 # if the validation error is going up, we are overfitting (or oscillating)
346 # stop converging but run at least to next validation 354 # check if we are allowed to continue and if we will adjust the learning rate
347 # to check overfitting or ocsillation
348 # the saved weights of the model will be a bit off in that case
349 elif this_validation_loss >= best_validation_loss: 355 elif this_validation_loss >= best_validation_loss:
356
357
358 # In non-classic learning rate decay, we lower the learning rate only when
359 # the validation error is going up
360 if adaptive_lr==1:
361 classifier.lr.value=classifier.lr.value*lr_t2_factor
362
363
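A sketch of the geometric decay applied above when adaptive_lr==1; the starting rate and the lr_t2_factor of 0.5 are assumed values for illustration:

lr = 0.1
lr_t2_factor = 0.5
for bad_validation in range(3):
    lr = lr * lr_t2_factor        # shrink the rate each time validation worsens
# after three consecutive bad validations: lr == 0.0125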
364 #count how many validations in a row the error has gone up; we allow at most max_div_count
365 #if it goes up max_div_count times in a row, the early stop check below ends the run immediately
366 divergence_flag = divergence_flag +1
367
368
350 #calculate the test error at this point and exit 369 #calculate the test error at this point and exit
351 # test it on the test set 370 # test it on the test set
352 # however, if adaptive_lr is true, try reducing the lr to
353 # get us out of an oscilliation
354 if adaptive_lr==1:
355 classifier.lr.value=classifier.lr.value*lr_t2_factor
356
357 test_score = 0. 371 test_score = 0.
358 #cap the patience so we are allowed one more validation error
359 #calculation before aborting
360 patience = minibatch_index+validation_frequency+1
361 temp=0 372 temp=0
362 for xt,yt in dataset.test(batch_size): 373 for xt,yt in dataset.test(batch_size):
363 test_score += test_model(xt,yt) 374 test_score += test_model(xt,yt)
364 temp=temp+1 375 temp=temp+1
365 test_score /= temp 376 test_score /= temp
370 (epoch, minibatch_index+1, 381 (epoch, minibatch_index+1,
371 test_score*100.)) 382 test_score*100.))
372 383
373 384
374 385
375 386 # check early stop condition
376 if minibatch_index>patience: 387 if divergence_flag==max_div_count:
377 print 'we have diverged' 388 minibatch_index=nb_max_exemples
389 print 'we have diverged, early stopping kicks in'
378 break 390 break
391
392 #check if we have seen enough exemples
393 #force one epoch at least
394 if epoch>0 and minibatch_index*batch_size>nb_max_exemples:
395 break
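The two break conditions above can be hard to follow inside the diff, so here is a standalone sketch of the early-stopping bookkeeping (the validation losses are made-up numbers; the counter resets on any new best and the loop aborts after max_div_count consecutive increases):

max_div_count = 3
divergence_flag = 0
best = float('inf')
for loss in [0.10, 0.08, 0.09, 0.095, 0.097, 0.07]:
    if loss < best:
        best = loss
        divergence_flag = 0       # improvement: reset the counter
    else:
        divergence_flag += 1      # validation error went up again
    if divergence_flag == max_div_count:
        break                     # corresponds to 'we have diverged, early stopping kicks in'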
379 396
380 397
381 time_n= time_n + batch_size 398 time_n= time_n + batch_size
399 minibatch_index = minibatch_index + 1
400
401 # we have finished looping through the training set
382 epoch = epoch+1 402 epoch = epoch+1
383 end_time = time.clock() 403 end_time = time.clock()
384 if verbose == 1: 404 if verbose == 1:
385 print(('Optimization complete. Best validation score of %f %% ' 405 print(('Optimization complete. Best validation score of %f %% '
386 'obtained at iteration %i, with test performance %f %%') % 406 'obtained at iteration %i, with test performance %f %%') %
389 print minibatch_index 409 print minibatch_index
390 410
391 #save the model and the weights 411 #save the model and the weights
392 numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value) 412 numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value)
393 numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\ 413 numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\
394 learning_rate_list=learning_rate_list) 414 learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list)
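Since the comment near the top of the training loop mentions keeping the curve of the validation error for later inspection, here is a minimal sketch of reading those curves back (same .npz naming caveat as for the best-model snapshot):

import numpy
results = numpy.load('results.npy.npz')
valid_curve = results['total_validation_error_list']    # one entry per validation pass
lr_curve = results['learning_rate_list']                 # learning rate at each validation pass
divergences = results['divergence_flag_list']            # consecutive-increase counter at each pass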
395 415
396 return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60) 416 return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60)
397 417
398 418
399 if __name__ == '__main__': 419 if __name__ == '__main__':
408 verbose = state.verbose,\ 428 verbose = state.verbose,\
409 train_data = state.train_data,\ 429 train_data = state.train_data,\
410 train_labels = state.train_labels,\ 430 train_labels = state.train_labels,\
411 test_data = state.test_data,\ 431 test_data = state.test_data,\
412 test_labels = state.test_labels,\ 432 test_labels = state.test_labels,\
413 lr_t2_factor=state.lr_t2_factor) 433 lr_t2_factor=state.lr_t2_factor,\
434 data_set=state.data_set)
414 state.train_error=train_error 435 state.train_error=train_error
415 state.validation_error=validation_error 436 state.validation_error=validation_error
416 state.test_error=test_error 437 state.test_error=test_error
417 state.nb_exemples=nb_exemples 438 state.nb_exemples=nb_exemples
418 state.time=time 439 state.time=time