Mercurial > ift6266
comparison deep/convolutional_dae/sgd_opt.py @ 288:80ee63c3e749
Add net saving (only the best model) and error saving using SeriesTable
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Fri, 26 Mar 2010 17:24:17 -0400 |
parents | 727ed56fad12 |
children | 8babd43235dd |
comparison
equal
deleted
inserted
replaced
286:1cc535f3e254 | 288:80ee63c3e749 |
---|---|
1 import time | 1 import time |
2 import sys | 2 import sys |
3 | 3 |
4 from ift6266.utils.seriestables import * | |
5 | |
6 default_series = { | |
7 'train_error' : DummySeries(), | |
8 'valid_error' : DummySeries(), | |
9 'test_error' : DummySeries() | |
10 } | |
11 | |
4 def sgd_opt(train, valid, test, training_epochs=10000, patience=10000, | 12 def sgd_opt(train, valid, test, training_epochs=10000, patience=10000, |
5 patience_increase=2., improvement_threshold=0.995, | 13 patience_increase=2., improvement_threshold=0.995, net=None, |
6 validation_frequency=None): | 14 validation_frequency=None, series=default_series): |
7 | 15 |
8 if validation_frequency is None: | 16 if validation_frequency is None: |
9 validation_frequency = patience/2 | 17 validation_frequency = patience/2 |
10 | 18 |
11 start_time = time.clock() | 19 start_time = time.clock() |
15 test_score = 0. | 23 test_score = 0. |
16 | 24 |
17 start_time = time.clock() | 25 start_time = time.clock() |
18 | 26 |
19 for epoch in xrange(1, training_epochs+1): | 27 for epoch in xrange(1, training_epochs+1): |
20 train() | 28 series['train_error'].append((epoch,), train()) |
21 | 29 |
22 if epoch % validation_frequency == 0: | 30 if epoch % validation_frequency == 0: |
23 this_validation_loss = valid() | 31 this_validation_loss = valid() |
32 series['valid_error'].append((epoch,), this_validation_loss*100.) | |
24 print('epoch %i, validation error %f %%' % \ | 33 print('epoch %i, validation error %f %%' % \ |
25 (epoch, this_validation_loss*100.)) | 34 (epoch, this_validation_loss*100.)) |
26 | 35 |
27 # if we got the best validation score until now | 36 # if we got the best validation score until now |
28 if this_validation_loss < best_validation_loss: | 37 if this_validation_loss < best_validation_loss: |
36 best_validation_loss = this_validation_loss | 45 best_validation_loss = this_validation_loss |
37 best_epoch = epoch | 46 best_epoch = epoch |
38 | 47 |
39 # test it on the test set | 48 # test it on the test set |
40 test_score = test() | 49 test_score = test() |
50 series['test_error'].append((epoch,), test_score*100.) | |
41 print((' epoch %i, test error of best model %f %%') % | 51 print((' epoch %i, test error of best model %f %%') % |
42 (epoch, test_score*100.)) | 52 (epoch, test_score*100.)) |
53 if net is not None: | |
54 net.save('best.net.new') | |
55 os.rename('best.net.new', 'best.net') | |
43 | 56 |
44 if patience <= epoch: | 57 if patience <= epoch: |
45 break | 58 break |
46 | 59 |
47 end_time = time.clock() | 60 end_time = time.clock() |