comparison deep/stacked_dae/v_sylvain/sgd_optimization.py @ 283:28b628f331b2

correction d'un bug sur l'indice des mini-batches
author SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca>
date Wed, 24 Mar 2010 14:58:58 -0400
parents a8b92a4a708d
children 1cc535f3e254
comparison
equal deleted inserted replaced
282:698313f8f6e6 283:28b628f331b2
193 193
194 done_looping = False 194 done_looping = False
195 epoch = 0 195 epoch = 0
196 196
197 total_mb_index = 0 197 total_mb_index = 0
198 minibatch_index = -1
198 199
199 while (epoch < num_finetune) and (not done_looping): 200 while (epoch < num_finetune) and (not done_looping):
200 epoch = epoch + 1 201 epoch = epoch + 1
201 minibatch_index = -1 202
202 for x,y in dataset.train(minibatch_size): 203 for x,y in dataset.train(minibatch_size):
203 minibatch_index += 1 204 minibatch_index += 1
204 if special == 0: 205 if special == 0:
205 cost_ij = self.classifier.finetune(x,y) 206 cost_ij = self.classifier.finetune(x,y)
206 elif special == 1: 207 elif special == 1:
208 total_mb_index += 1 209 total_mb_index += 1
209 210
210 self.series["training_error"].append((epoch, minibatch_index), cost_ij) 211 self.series["training_error"].append((epoch, minibatch_index), cost_ij)
211 212
212 if (total_mb_index+1) % validation_frequency == 0: 213 if (total_mb_index+1) % validation_frequency == 0:
213 214 #minibatch_index += 1
214 #The validation set is always NIST 215 #The validation set is always NIST
215 if ind_test == 0: 216 if ind_test == 0:
216 iter=dataset_test.valid(minibatch_size) 217 iter=dataset_test.valid(minibatch_size)
217 else: 218 else:
218 iter = dataset.valid(minibatch_size) 219 iter = dataset.valid(minibatch_size)