Mercurial > ift6266
comparison deep/stacked_dae/v2/sgd_optimization.py @ 240:f213a0fb2b08
Corrigé bugs dans stacked_dae/v2 en rapport avec les modifs d'hier
author | fsavard |
---|---|
date | Tue, 16 Mar 2010 10:52:04 -0400 |
parents | 42005ec87747 |
children | 8a00764ea8a4 |
comparison
equal
deleted
inserted
replaced
239:42005ec87747 | 240:f213a0fb2b08 |
---|---|
27 | 27 |
28 def itermax(iter, max): | 28 def itermax(iter, max): |
29 for i,it in enumerate(iter): | 29 for i,it in enumerate(iter): |
30 if i >= max: | 30 if i >= max: |
31 break | 31 break |
32 yield i | 32 yield it |
33 | 33 |
34 class SdaSgdOptimizer: | 34 class SdaSgdOptimizer: |
35 def __init__(self, dataset, hyperparameters, n_ins, n_outs, | 35 def __init__(self, dataset, hyperparameters, n_ins, n_outs, |
36 examples_per_epoch, series=default_series, max_minibatches=None): | 36 examples_per_epoch, series=default_series, max_minibatches=None): |
37 self.dataset = dataset | 37 self.dataset = dataset |
96 c = self.classifier.pretrain_functions[i](x) | 96 c = self.classifier.pretrain_functions[i](x) |
97 | 97 |
98 self.series["reconstruction_error"].append((epoch, batch_index), c) | 98 self.series["reconstruction_error"].append((epoch, batch_index), c) |
99 batch_index+=1 | 99 batch_index+=1 |
100 | 100 |
101 if batch_index % 10000 == 0: | 101 if batch_index % 100 == 0: |
102 print "10000 batches" | 102 print "100 batches" |
103 | 103 |
104 # useful when doing tests | 104 # useful when doing tests |
105 if self.max_minibatches and batch_index >= self.max_minibatches: | 105 if self.max_minibatches and batch_index >= self.max_minibatches: |
106 break | 106 break |
107 | 107 |
148 validation_frequency = min(self.mb_per_epoch, patience/2) | 148 validation_frequency = min(self.mb_per_epoch, patience/2) |
149 # go through this many | 149 # go through this many |
150 # minibatche before checking the network | 150 # minibatche before checking the network |
151 # on the validation set; in this case we | 151 # on the validation set; in this case we |
152 # check every epoch | 152 # check every epoch |
153 if self.max_minibatches and validation_frequency > self.max_minibatches: | |
154 validation_frequency = self.max_minibatches / 2 | |
153 | 155 |
154 best_params = None | 156 best_params = None |
155 best_validation_loss = float('inf') | 157 best_validation_loss = float('inf') |
156 test_score = 0. | 158 test_score = 0. |
157 start_time = time.clock() | 159 start_time = time.clock() |
181 | 183 |
182 self.series["validation_error"].\ | 184 self.series["validation_error"].\ |
183 append((epoch, minibatch_index), this_validation_loss*100.) | 185 append((epoch, minibatch_index), this_validation_loss*100.) |
184 | 186 |
185 print('epoch %i, minibatch %i/%i, validation error %f %%' % \ | 187 print('epoch %i, minibatch %i/%i, validation error %f %%' % \ |
186 (epoch, minibatch_index+1, self.n_train_batches, \ | 188 (epoch, minibatch_index+1, self.mb_per_epoch, \ |
187 this_validation_loss*100.)) | 189 this_validation_loss*100.)) |
188 | 190 |
189 | 191 |
190 # if we got the best validation score until now | 192 # if we got the best validation score until now |
191 if this_validation_loss < best_validation_loss: | 193 if this_validation_loss < best_validation_loss: |
209 self.series["test_error"].\ | 211 self.series["test_error"].\ |
210 append((epoch, minibatch_index), test_score*100.) | 212 append((epoch, minibatch_index), test_score*100.) |
211 | 213 |
212 print((' epoch %i, minibatch %i/%i, test error of best ' | 214 print((' epoch %i, minibatch %i/%i, test error of best ' |
213 'model %f %%') % | 215 'model %f %%') % |
214 (epoch, minibatch_index+1, self.n_train_batches, | 216 (epoch, minibatch_index+1, self.mb_per_epoch, |
215 test_score*100.)) | 217 test_score*100.)) |
216 | 218 |
217 sys.stdout.flush() | 219 sys.stdout.flush() |
218 | 220 |
219 # useful when doing tests | 221 # useful when doing tests |
220 if self.max_minibatches and batch_index >= self.max_minibatches: | 222 if self.max_minibatches and minibatch_index >= self.max_minibatches: |
221 break | 223 break |
222 | 224 |
223 self.series['params'].append((epoch,), self.classifier.all_params) | 225 self.series['params'].append((epoch,), self.classifier.all_params) |
224 | 226 |
225 if patience <= total_mb_index: | 227 if patience <= total_mb_index: |