comparison deep/stacked_dae/v2/sgd_optimization.py @ 240:f213a0fb2b08

Corrigé bugs dans stacked_dae/v2 en rapport avec les modifs d'hier
author fsavard
date Tue, 16 Mar 2010 10:52:04 -0400
parents 42005ec87747
children 8a00764ea8a4
comparison
equal deleted inserted replaced
239:42005ec87747 240:f213a0fb2b08
27 27
28 def itermax(iter, max): 28 def itermax(iter, max):
29 for i,it in enumerate(iter): 29 for i,it in enumerate(iter):
30 if i >= max: 30 if i >= max:
31 break 31 break
32 yield i 32 yield it
33 33
34 class SdaSgdOptimizer: 34 class SdaSgdOptimizer:
35 def __init__(self, dataset, hyperparameters, n_ins, n_outs, 35 def __init__(self, dataset, hyperparameters, n_ins, n_outs,
36 examples_per_epoch, series=default_series, max_minibatches=None): 36 examples_per_epoch, series=default_series, max_minibatches=None):
37 self.dataset = dataset 37 self.dataset = dataset
96 c = self.classifier.pretrain_functions[i](x) 96 c = self.classifier.pretrain_functions[i](x)
97 97
98 self.series["reconstruction_error"].append((epoch, batch_index), c) 98 self.series["reconstruction_error"].append((epoch, batch_index), c)
99 batch_index+=1 99 batch_index+=1
100 100
101 if batch_index % 10000 == 0: 101 if batch_index % 100 == 0:
102 print "10000 batches" 102 print "100 batches"
103 103
104 # useful when doing tests 104 # useful when doing tests
105 if self.max_minibatches and batch_index >= self.max_minibatches: 105 if self.max_minibatches and batch_index >= self.max_minibatches:
106 break 106 break
107 107
148 validation_frequency = min(self.mb_per_epoch, patience/2) 148 validation_frequency = min(self.mb_per_epoch, patience/2)
149 # go through this many 149 # go through this many
150 # minibatche before checking the network 150 # minibatche before checking the network
151 # on the validation set; in this case we 151 # on the validation set; in this case we
152 # check every epoch 152 # check every epoch
153 if self.max_minibatches and validation_frequency > self.max_minibatches:
154 validation_frequency = self.max_minibatches / 2
153 155
154 best_params = None 156 best_params = None
155 best_validation_loss = float('inf') 157 best_validation_loss = float('inf')
156 test_score = 0. 158 test_score = 0.
157 start_time = time.clock() 159 start_time = time.clock()
181 183
182 self.series["validation_error"].\ 184 self.series["validation_error"].\
183 append((epoch, minibatch_index), this_validation_loss*100.) 185 append((epoch, minibatch_index), this_validation_loss*100.)
184 186
185 print('epoch %i, minibatch %i/%i, validation error %f %%' % \ 187 print('epoch %i, minibatch %i/%i, validation error %f %%' % \
186 (epoch, minibatch_index+1, self.n_train_batches, \ 188 (epoch, minibatch_index+1, self.mb_per_epoch, \
187 this_validation_loss*100.)) 189 this_validation_loss*100.))
188 190
189 191
190 # if we got the best validation score until now 192 # if we got the best validation score until now
191 if this_validation_loss < best_validation_loss: 193 if this_validation_loss < best_validation_loss:
209 self.series["test_error"].\ 211 self.series["test_error"].\
210 append((epoch, minibatch_index), test_score*100.) 212 append((epoch, minibatch_index), test_score*100.)
211 213
212 print((' epoch %i, minibatch %i/%i, test error of best ' 214 print((' epoch %i, minibatch %i/%i, test error of best '
213 'model %f %%') % 215 'model %f %%') %
214 (epoch, minibatch_index+1, self.n_train_batches, 216 (epoch, minibatch_index+1, self.mb_per_epoch,
215 test_score*100.)) 217 test_score*100.))
216 218
217 sys.stdout.flush() 219 sys.stdout.flush()
218 220
219 # useful when doing tests 221 # useful when doing tests
220 if self.max_minibatches and batch_index >= self.max_minibatches: 222 if self.max_minibatches and minibatch_index >= self.max_minibatches:
221 break 223 break
222 224
223 self.series['params'].append((epoch,), self.classifier.all_params) 225 self.series['params'].append((epoch,), self.classifier.all_params)
224 226
225 if patience <= total_mb_index: 227 if patience <= total_mb_index: