diff deep/convolutional_dae/sgd_opt.py @ 279:206374eed2fb

Merge
author fsavard
date Wed, 24 Mar 2010 14:36:55 -0400
parents 727ed56fad12
children 80ee63c3e749
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/deep/convolutional_dae/sgd_opt.py	Wed Mar 24 14:36:55 2010 -0400
@@ -0,0 +1,52 @@
+import time
+import sys
+
+def sgd_opt(train, valid, test, training_epochs=10000, patience=10000,
+            patience_increase=2., improvement_threshold=0.995,
+            validation_frequency=None):
+
+    if validation_frequency is None:
+        validation_frequency = patience/2
+ 
+    start_time = time.clock()
+
+    best_params = None
+    best_validation_loss = float('inf')
+    test_score = 0.
+
+    start_time = time.clock()
+ 
+    for epoch in xrange(1, training_epochs+1):
+        train()
+
+        if epoch % validation_frequency == 0:
+            this_validation_loss = valid()
+            print('epoch %i, validation error %f %%' % \
+                   (epoch, this_validation_loss*100.))
+            
+            # if we got the best validation score until now
+            if this_validation_loss < best_validation_loss:
+ 
+                #improve patience if loss improvement is good enough
+                if this_validation_loss < best_validation_loss * \
+                       improvement_threshold :
+                    patience = max(patience, epoch * patience_increase)
+                
+                # save best validation score and epoch number
+                best_validation_loss = this_validation_loss
+                best_epoch = epoch
+                
+                # test it on the test set
+                test_score = test()
+                print((' epoch %i, test error of best model %f %%') %
+                      (epoch, test_score*100.))
+                
+        if patience <= epoch:
+            break
+    
+    end_time = time.clock()
+    print(('Optimization complete with best validation score of %f %%,'
+           'with test performance %f %%') %
+                 (best_validation_loss * 100., test_score*100.))
+    print ('The code ran for %f minutes' % ((end_time-start_time)/60.))
+