comparison baseline/mlp/mlp_nist.py @ 355:76b7182dd32e

Added support for pnist in the iterator; corrected a print bug in mlp.
author xaviermuller
date Wed, 21 Apr 2010 15:07:09 -0400
parents 22efb4968054
children 60a4432b8071
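This changeset drops the `if verbose == 1:` guards around the progress messages, follows each print with `sys.stdout.flush()` so output shows up immediately even when stdout is block-buffered (for example, when redirected to a log file on a cluster job), and moves the intermediate `numpy.savez('temp_results.npy', ...)` dump from before the validation pass to just after the new validation error has been appended, so the checkpoint includes the latest measurement. A minimal, self-contained sketch of the print-and-flush pattern (not taken from the changeset):

    import sys
    import time

    for step in range(3):
        time.sleep(1)                       # stand-in for one training step
        print('finished step %i' % step)    # print unconditionally, no verbose guard
        sys.stdout.flush()                  # push the line into the log right away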
--- a/baseline/mlp/mlp_nist.py  354:ffc06af1c543
+++ b/baseline/mlp/mlp_nist.py  355:76b7182dd32e
@@ -21,10 +21,11 @@
 to do lr first, then add regularization)
 
 """
 __docformat__ = 'restructedtext en'
 
+import sys
 import pdb
 import numpy
 import pylab
 import theano
 import theano.tensor as T
@@ -370,12 +371,13 @@
 temp=0
 divergence_flag=0
 
 
 
-if verbose == 1:
+
 print 'starting training'
+sys.stdout.flush()
 while(minibatch_index*batch_size<nb_max_exemples):
 
 for x, y in dataset.train(batch_size):
 
 #if we are using the classic learning rate deacay, adjust it before training of current mini-batch
@@ -389,13 +391,11 @@
 if (minibatch_index) % validation_frequency == 0:
 #save the current learning rate
 learning_rate_list.append(classifier.lr.value)
 divergence_flag_list.append(divergence_flag)
 
-#save temp results to check during training
-numpy.savez('temp_results.npy',config=configuration,total_validation_error_list=total_validation_error_list,\
-learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list)
+
 
 # compute the validation error
 this_validation_loss = 0.
 temp=0
 for xv,yv in dataset.valid(1):
@@ -404,14 +404,19 @@
 temp=temp+1
 # get the average by dividing with the number of minibatches
 this_validation_loss /= temp
 #save the validation loss
 total_validation_error_list.append(this_validation_loss)
-if verbose == 1:
+
 print(('epoch %i, minibatch %i, learning rate %f current validation error %f ') %
 (epoch, minibatch_index+1,classifier.lr.value,
 this_validation_loss*100.))
+sys.stdout.flush()
+
+#save temp results to check during training
+numpy.savez('temp_results.npy',config=configuration,total_validation_error_list=total_validation_error_list,\
+learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list)
 
 # if we got the best validation score until now
 if this_validation_loss < best_validation_loss:
 # save best validation score and iteration number
 best_validation_loss = this_validation_loss
@@ -429,15 +434,16 @@
 temp =0
 for xt,yt in dataset.test(batch_size):
 test_score += test_model(xt,yt)
 temp = temp+1
 test_score /= temp
-if verbose == 1:
+
 print(('epoch %i, minibatch %i, test error of best '
 'model %f %%') %
 (epoch, minibatch_index+1,
 test_score*100.))
+sys.stdout.flush()
 
 # if the validation error is going up, we are overfitting (or oscillating)
 # check if we are allowed to continue and if we will adjust the learning rate
 elif this_validation_loss >= best_validation_loss:
 
@@ -459,16 +465,17 @@
 temp=0
 for xt,yt in dataset.test(batch_size):
 test_score += test_model(xt,yt)
 temp=temp+1
 test_score /= temp
-if verbose == 1:
+
 print ' validation error is going up, possibly stopping soon'
 print((' epoch %i, minibatch %i, test error of best '
 'model %f %%') %
 (epoch, minibatch_index+1,
 test_score*100.))
+sys.stdout.flush()
 
 
 
 # check early stop condition
 if divergence_flag==max_div_count:
@@ -489,16 +496,17 @@
 minibatch_index = minibatch_index + 1
 
 # we have finished looping through the training set
 epoch = epoch+1
 end_time = time.clock()
-if verbose == 1:
+
 print(('Optimization complete. Best validation score of %f %% '
 'obtained at iteration %i, with test performance %f %%') %
 (best_validation_loss * 100., best_iter, test_score*100.))
 print ('The code ran for %f minutes' % ((end_time-start_time)/60.))
 print minibatch_index
+sys.stdout.flush()
 
 #save the model and the weights
 numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value)
 numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\
 learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list)
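A run of this script leaves behind the numpy.savez archives written above (temp_results.npy, model.npy, results.npy). A small sketch of how they could be read back afterwards; note that numpy.savez appends '.npz' to a filename that does not already end in it, so the files land on disk as temp_results.npy.npz, model.npy.npz and results.npy.npz:

    import numpy

    # keys: config, total_train_error_list, total_validation_error_list,
    #       learning_rate_list, divergence_flag_list
    results = numpy.load('results.npy.npz')
    # keys: config, W1, b1, W2, b2
    model = numpy.load('model.npy.npz')

    print('validation error after each check: %s'
          % results['total_validation_error_list'])
    print('W1 shape: %s' % (model['W1'].shape,))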
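The hunks above also reference a divergence_flag counter, a divergence_flag_list checkpointed alongside the learning rate, and the early-stop test divergence_flag==max_div_count. The counter update itself falls outside the lines shown, so the following is only a hypothetical illustration of that kind of divergence counting, with made-up values, not the author's actual logic:

    # Hypothetical illustration of divergence counting for early stopping.
    best_validation_loss = float('inf')
    divergence_flag = 0
    max_div_count = 3   # assumed limit; the real value comes from the configuration

    def should_stop(this_validation_loss):
        """Return True once validation error has failed to improve max_div_count times."""
        global best_validation_loss, divergence_flag
        if this_validation_loss < best_validation_loss:
            best_validation_loss = this_validation_loss
            divergence_flag = 0      # improvement: reset the counter
        else:
            divergence_flag += 1     # validation error went up (or stalled)
        return divergence_flag == max_div_count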