diff deep/convolutional_dae/sgd_opt.py @ 288:80ee63c3e749

Add net saving (only the best model) and error saving using SeriesTable
author Arnaud Bergeron <abergeron@gmail.com>
date Fri, 26 Mar 2010 17:24:17 -0400
parents 727ed56fad12
children 8babd43235dd
--- a/deep/convolutional_dae/sgd_opt.py	Thu Mar 25 12:20:27 2010 -0400
+++ b/deep/convolutional_dae/sgd_opt.py	Fri Mar 26 17:24:17 2010 -0400
@@ -1,9 +1,18 @@
 import time
 import sys
+import os
 
+from ift6266.utils.seriestables import *
+
+default_series = {
+    'train_error' : DummySeries(),
+    'valid_error' : DummySeries(),
+    'test_error' : DummySeries()
+    }
+
 def sgd_opt(train, valid, test, training_epochs=10000, patience=10000,
-            patience_increase=2., improvement_threshold=0.995,
-            validation_frequency=None):
+            patience_increase=2., improvement_threshold=0.995, net=None,
+            validation_frequency=None, series=default_series):
 
     if validation_frequency is None:
         validation_frequency = patience/2
@@ -17,10 +26,11 @@
     start_time = time.clock()
  
     for epoch in xrange(1, training_epochs+1):
-        train()
+        series['train_error'].append((epoch,), train())
 
         if epoch % validation_frequency == 0:
             this_validation_loss = valid()
+            series['valid_error'].append((epoch,), this_validation_loss*100.)
             print('epoch %i, validation error %f %%' % \
                    (epoch, this_validation_loss*100.))
             
@@ -38,8 +48,12 @@
                 
                 # test it on the test set
                 test_score = test()
+                series['test_error'].append((epoch,), test_score*100.)
                 print((' epoch %i, test error of best model %f %%') %
                       (epoch, test_score*100.))
+                if net is not None:
+                    net.save('best.net.new')
+                    os.rename('best.net.new', 'best.net')
                 
         if patience <= epoch:
             break
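
For reference, a minimal usage sketch of the new hooks. It assumes train, valid, test are the usual per-epoch callables and that my_net exposes the save(filename) method the diff relies on; FileSeries is a hypothetical stand-in that only illustrates the append((epoch,), value) interface expected of the entries in the series dict (the real SeriesTable objects, or the DummySeries defaults, provide the same call):

    # Hypothetical series object: anything with append(index, value) works here.
    class FileSeries(object):
        """Logs (epoch, value) pairs to a plain text file."""
        def __init__(self, path):
            self.path = path

        def append(self, index, value):
            # index is a tuple such as (epoch,); value is the error being recorded
            f = open(self.path, 'a')
            f.write('%d\t%f\n' % (index[0], value))
            f.close()

    series = {
        'train_error': FileSeries('train_error.log'),
        'valid_error': FileSeries('valid_error.log'),
        'test_error': FileSeries('test_error.log'),
        }

    # net only needs a save(filename) method; sgd_opt writes to 'best.net.new'
    # first and then renames it, so an interrupted save never clobbers 'best.net'.
    sgd_opt(train, valid, test, training_epochs=1000, net=my_net, series=series)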