ift6266: diff deep/stacked_dae/sgd_optimization.py @ 284:8a3af19ae272
Removed the mechanism for limiting the number of examples used (replaced by a parameter in the call to the dataset code), and added an option to save the weights at the end of training.
author:   fsavard
date:     Wed, 24 Mar 2010 15:13:48 -0400
parents:  7b4507295eba
children: (none)
--- a/deep/stacked_dae/sgd_optimization.py	Wed Mar 24 14:36:55 2010 -0400
+++ b/deep/stacked_dae/sgd_optimization.py	Wed Mar 24 15:13:48 2010 -0400
@@ -3,6 +3,8 @@
 
 # Generic SdA optimization loop, adapted from the deeplearning.net tutorial
 
+from __future__ import with_statement
+
 import numpy
 import theano
 import time
@@ -25,22 +27,16 @@
         'params' : DummySeries()
         }
 
-def itermax(iter, max):
-    for i,it in enumerate(iter):
-        if i >= max:
-            break
-        yield it
-
 class SdaSgdOptimizer:
     def __init__(self, dataset, hyperparameters, n_ins, n_outs,
-                    examples_per_epoch, series=default_series, max_minibatches=None):
+                    examples_per_epoch, series=default_series,
+                    save_params=False):
         self.dataset = dataset
         self.hp = hyperparameters
         self.n_ins = n_ins
         self.n_outs = n_outs
-
-        self.max_minibatches = max_minibatches
-        print "SdaSgdOptimizer, max_minibatches =", max_minibatches
+
+        self.save_params = save_params
 
         self.ex_per_epoch = examples_per_epoch
        self.mb_per_epoch = examples_per_epoch / self.hp.minibatch_size
@@ -101,10 +97,6 @@
                     #if batch_index % 100 == 0:
                     #    print "100 batches"
 
-                    # useful when doing tests
-                    if self.max_minibatches and batch_index >= self.max_minibatches:
-                        break
-
                 print 'Pre-training layer %i, epoch %d, cost '%(i,epoch),c
                 sys.stdout.flush()
 
@@ -150,8 +142,6 @@
                                       # minibatche before checking the network
                                       # on the validation set; in this case we
                                       # check every epoch
-        if self.max_minibatches and validation_frequency > self.max_minibatches:
-            validation_frequency = self.max_minibatches / 2
 
         best_params = None
         best_validation_loss = float('inf')
@@ -176,8 +166,6 @@
 
                 if (total_mb_index+1) % validation_frequency == 0:
 
                     iter = dataset.valid(minibatch_size)
-                    if self.max_minibatches:
-                        iter = itermax(iter, self.max_minibatches)
                     validation_losses = [validate_model(x,y) for x,y in iter]
                     this_validation_loss = numpy.mean(validation_losses)
@@ -203,8 +191,6 @@
                         # test it on the test set
 
                         iter = dataset.test(minibatch_size)
-                        if self.max_minibatches:
-                            iter = itermax(iter, self.max_minibatches)
                         test_losses = [test_model(x,y) for x,y in iter]
                         test_score = numpy.mean(test_losses)
 
@@ -218,10 +204,6 @@
 
                     sys.stdout.flush()
 
-            # useful when doing tests
-            if self.max_minibatches and minibatch_index >= self.max_minibatches:
-                break
-
             self.series['params'].append((epoch,), self.classifier.all_params)
 
             if patience <= total_mb_index:
@@ -234,6 +216,9 @@
                         'test_score':test_score,
                         'num_finetuning_epochs':epoch})
 
+        if self.save_params:
+            save_params(self.classifier.all_params, "weights.dat")
+
         print(('Optimization complete with best validation score of %f %%,'
                'with test performance %f %%') %
                      (best_validation_loss * 100., test_score*100.))
@@ -241,3 +226,11 @@
 
 
 
+def save_params(all_params, filename):
+    import pickle
+    with open(filename, 'wb') as f:
+        values = [p.value for p in all_params]
+
+        # -1 for HIGHEST_PROTOCOL
+        pickle.dump(values, f, -1)
+
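The new save_params helper only writes the weights out; nothing in this changeset reads them back. A minimal sketch of a matching loader, assuming the file was produced by save_params above and that all_params are the same Theano shared variables in the same order (using the 0.3-era .value attribute seen in the diff), might look like this; the name load_params is hypothetical:

```python
import pickle

def load_params(all_params, filename):
    # Hypothetical counterpart to save_params(): read the pickled list of
    # weight arrays and copy each one back into the corresponding Theano
    # shared variable, relying on the parameters being in the same order
    # as when they were saved.
    with open(filename, 'rb') as f:
        values = pickle.load(f)
    for p, v in zip(all_params, values):
        p.value = v  # old Theano API, as in the diff; newer versions use p.set_value(v)
```

With the new constructor flag, saving would be enabled with something like SdaSgdOptimizer(..., save_params=True), and the weights would be written to weights.dat in the job's working directory at the end of finetuning.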