comparison deep/convolutional_dae/sgd_opt.py @ 288:80ee63c3e749

Add net saving (only the best model) and error saving using SeriesTable
author Arnaud Bergeron <abergeron@gmail.com>
date Fri, 26 Mar 2010 17:24:17 -0400
parents 727ed56fad12
children 8babd43235dd
comparison
equal deleted inserted replaced
286:1cc535f3e254 288:80ee63c3e749
1 import time 1 import time
2 import sys 2 import sys
3 3
4 from ift6266.utils.seriestables import *
5
# Default mapping for sgd_opt's `series` argument: one sink per split
# ('train_error', 'valid_error', 'test_error'). sgd_opt calls
# `.append((epoch,), value)` on these entries. DummySeries comes from
# ift6266.utils.seriestables and is presumably a no-op sink (TODO confirm).
# NOTE(review): this dict is used as a default argument value, so the same
# dict and the same DummySeries instances are shared by every sgd_opt call
# that omits `series`.
6 default_series = {
7 'train_error' : DummySeries(),
8 'valid_error' : DummySeries(),
9 'test_error' : DummySeries()
10 }
11
4 def sgd_opt(train, valid, test, training_epochs=10000, patience=10000, 12 def sgd_opt(train, valid, test, training_epochs=10000, patience=10000,
5 patience_increase=2., improvement_threshold=0.995, 13 patience_increase=2., improvement_threshold=0.995, net=None,
6 validation_frequency=None): 14 validation_frequency=None, series=default_series):
7 15
8 if validation_frequency is None: 16 if validation_frequency is None:
9 validation_frequency = patience/2 17 validation_frequency = patience/2
10 18
11 start_time = time.clock() 19 start_time = time.clock()
15 test_score = 0. 23 test_score = 0.
16 24
17 start_time = time.clock() 25 start_time = time.clock()
18 26
19 for epoch in xrange(1, training_epochs+1): 27 for epoch in xrange(1, training_epochs+1):
20 train() 28 series['train_error'].append((epoch,), train())
21 29
22 if epoch % validation_frequency == 0: 30 if epoch % validation_frequency == 0:
23 this_validation_loss = valid() 31 this_validation_loss = valid()
32 series['valid_error'].append((epoch,), this_validation_loss*100.)
24 print('epoch %i, validation error %f %%' % \ 33 print('epoch %i, validation error %f %%' % \
25 (epoch, this_validation_loss*100.)) 34 (epoch, this_validation_loss*100.))
26 35
27 # if we got the best validation score until now 36 # if we got the best validation score until now
28 if this_validation_loss < best_validation_loss: 37 if this_validation_loss < best_validation_loss:
36 best_validation_loss = this_validation_loss 45 best_validation_loss = this_validation_loss
37 best_epoch = epoch 46 best_epoch = epoch
38 47
39 # test it on the test set 48 # test it on the test set
40 test_score = test() 49 test_score = test()
50 series['test_error'].append((epoch,), test_score*100.)
41 print((' epoch %i, test error of best model %f %%') % 51 print((' epoch %i, test error of best model %f %%') %
42 (epoch, test_score*100.)) 52 (epoch, test_score*100.))
53 if net is not None:
54 net.save('best.net.new')
55 os.rename('best.net.new', 'best.net')
43 56
44 if patience <= epoch: 57 if patience <= epoch:
45 break 58 break
46 59
47 end_time = time.clock() 60 end_time = time.clock()