Mercurial > ift6266
changeset 240:f213a0fb2b08
Corrigé bugs dans stacked_dae/v2 en rapport avec les modifs d'hier
author | fsavard |
---|---|
date | Tue, 16 Mar 2010 10:52:04 -0400 |
parents | 42005ec87747 |
children | c24020aa38ac 8a00764ea8a4 |
files | deep/stacked_dae/v2/sgd_optimization.py |
diffstat | 1 files changed, 8 insertions(+), 6 deletions(-) [+] |
line wrap: on
line diff
--- a/deep/stacked_dae/v2/sgd_optimization.py Mon Mar 15 18:30:21 2010 -0400 +++ b/deep/stacked_dae/v2/sgd_optimization.py Tue Mar 16 10:52:04 2010 -0400 @@ -29,7 +29,7 @@ for i,it in enumerate(iter): if i >= max: break - yield i + yield it class SdaSgdOptimizer: def __init__(self, dataset, hyperparameters, n_ins, n_outs, @@ -98,8 +98,8 @@ self.series["reconstruction_error"].append((epoch, batch_index), c) batch_index+=1 - if batch_index % 10000 == 0: - print "10000 batches" + if batch_index % 100 == 0: + print "100 batches" # useful when doing tests if self.max_minibatches and batch_index >= self.max_minibatches: @@ -150,6 +150,8 @@ # minibatche before checking the network # on the validation set; in this case we # check every epoch + if self.max_minibatches and validation_frequency > self.max_minibatches: + validation_frequency = self.max_minibatches / 2 best_params = None best_validation_loss = float('inf') @@ -183,7 +185,7 @@ append((epoch, minibatch_index), this_validation_loss*100.) print('epoch %i, minibatch %i/%i, validation error %f %%' % \ - (epoch, minibatch_index+1, self.n_train_batches, \ + (epoch, minibatch_index+1, self.mb_per_epoch, \ this_validation_loss*100.)) @@ -211,13 +213,13 @@ print((' epoch %i, minibatch %i/%i, test error of best ' 'model %f %%') % - (epoch, minibatch_index+1, self.n_train_batches, + (epoch, minibatch_index+1, self.mb_per_epoch, test_score*100.)) sys.stdout.flush() # useful when doing tests - if self.max_minibatches and batch_index >= self.max_minibatches: + if self.max_minibatches and minibatch_index >= self.max_minibatches: break self.series['params'].append((epoch,), self.classifier.all_params)