Mercurial > ift6266
diff deep/convolutional_dae/sgd_opt.py @ 288:80ee63c3e749
Add net saving (only the best model) and error saving using SeriesTable
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Fri, 26 Mar 2010 17:24:17 -0400 |
parents | 727ed56fad12 |
children | 8babd43235dd |
line wrap: on
line diff
--- a/deep/convolutional_dae/sgd_opt.py Thu Mar 25 12:20:27 2010 -0400 +++ b/deep/convolutional_dae/sgd_opt.py Fri Mar 26 17:24:17 2010 -0400 @@ -1,9 +1,17 @@ import time import sys +from ift6266.utils.seriestables import * + +default_series = { + 'train_error' : DummySeries(), + 'valid_error' : DummySeries(), + 'test_error' : DummySeries() + } + def sgd_opt(train, valid, test, training_epochs=10000, patience=10000, - patience_increase=2., improvement_threshold=0.995, - validation_frequency=None): + patience_increase=2., improvement_threshold=0.995, net=None, + validation_frequency=None, series=default_series): if validation_frequency is None: validation_frequency = patience/2 @@ -17,10 +25,11 @@ start_time = time.clock() for epoch in xrange(1, training_epochs+1): - train() + series['train_error'].append((epoch,), train()) if epoch % validation_frequency == 0: this_validation_loss = valid() + series['valid_error'].append((epoch,), this_validation_loss*100.) print('epoch %i, validation error %f %%' % \ (epoch, this_validation_loss*100.)) @@ -38,8 +47,12 @@ # test it on the test set test_score = test() + series['test_error'].append((epoch,), test_score*100.) print((' epoch %i, test error of best model %f %%') % (epoch, test_score*100.)) + if net is not None: + net.save('best.net.new') + os.rename('best.net.new', 'best.net') if patience <= epoch: break