comparison deep/stacked_dae/sgd_optimization.py @ 284:8a3af19ae272

Removed the machinery for limiting the number of examples used (replaced by a parameter in the call to the dataset code), and added an option for saving the weights at the end of training
author fsavard
date Wed, 24 Mar 2010 15:13:48 -0400
parents 7b4507295eba
children
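The example-limiting machinery removed here (the itermax() helper and the max_minibatches checks visible in the hunks below) moves into the dataset code, which is not part of this file. The following is only a minimal sketch of what such a dataset-side cap could look like; dataset_iterator, examples and max_examples are hypothetical names, not identifiers from this changeset:

    # Hypothetical sketch of a dataset-side cap replacing the
    # optimizer-side itermax()/max_minibatches mechanism removed below.
    # All names here are assumptions, not from this changeset.
    def dataset_iterator(examples, minibatch_size, max_examples=None):
        served = 0
        for i in xrange(0, len(examples), minibatch_size):
            # stop once the requested number of examples has been served
            if max_examples is not None and served >= max_examples:
                break
            yield examples[i:i + minibatch_size]
            served += minibatch_size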
diff -r 206374eed2fb -r 8a3af19ae272 deep/stacked_dae/sgd_optimization.py
--- a/deep/stacked_dae/sgd_optimization.py
+++ b/deep/stacked_dae/sgd_optimization.py
@@ -1,9 +1,11 @@
 #!/usr/bin/python
 # coding: utf-8
 
 # Generic SdA optimization loop, adapted from the deeplearning.net tutorial
+
+from __future__ import with_statement
 
 import numpy
 import theano
 import time
 import datetime
@@ -23,26 +25,20 @@
     'validation_error' : DummySeries(),
     'test_error' : DummySeries(),
     'params' : DummySeries()
     }
 
-def itermax(iter, max):
-    for i,it in enumerate(iter):
-        if i >= max:
-            break
-        yield it
-
 class SdaSgdOptimizer:
     def __init__(self, dataset, hyperparameters, n_ins, n_outs,
-                    examples_per_epoch, series=default_series, max_minibatches=None):
+                    examples_per_epoch, series=default_series,
+                    save_params=False):
         self.dataset = dataset
         self.hp = hyperparameters
         self.n_ins = n_ins
         self.n_outs = n_outs
 
-        self.max_minibatches = max_minibatches
-        print "SdaSgdOptimizer, max_minibatches =", max_minibatches
+        self.save_params = save_params
 
         self.ex_per_epoch = examples_per_epoch
         self.mb_per_epoch = examples_per_epoch / self.hp.minibatch_size
 
         self.series = series
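With the new signature, weight saving is requested at construction time instead of passing max_minibatches. A minimal usage sketch; my_dataset, hp and the numeric values are placeholders, not taken from this changeset:

    # Hypothetical call site for the new constructor signature.
    # `my_dataset` and `hp` stand in for the dataset wrapper and the
    # hyperparameters object this optimizer expects.
    optimizer = SdaSgdOptimizer(dataset=my_dataset,
                                hyperparameters=hp,
                                n_ins=32*32, n_outs=10,
                                examples_per_epoch=100000,
                                save_params=True)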
@@ -99,14 +95,10 @@
                 batch_index+=1
 
                 #if batch_index % 100 == 0:
                 #    print "100 batches"
 
-                # useful when doing tests
-                if self.max_minibatches and batch_index >= self.max_minibatches:
-                    break
-
             print 'Pre-training layer %i, epoch %d, cost '%(i,epoch),c
             sys.stdout.flush()
 
             self.series['params'].append((epoch,), self.classifier.all_params)
 
@@ -148,12 +140,10 @@
         validation_frequency = min(self.mb_per_epoch, patience/2)
                                       # go through this many
                                       # minibatches before checking the network
                                       # on the validation set; in this case we
                                       # check every epoch
-        if self.max_minibatches and validation_frequency > self.max_minibatches:
-            validation_frequency = self.max_minibatches / 2
 
         best_params = None
         best_validation_loss = float('inf')
         test_score = 0.
         start_time = time.clock()
@@ -174,12 +164,10 @@
                 self.series["training_error"].append((epoch, minibatch_index), cost_ij)
 
                 if (total_mb_index+1) % validation_frequency == 0:
 
                     iter = dataset.valid(minibatch_size)
-                    if self.max_minibatches:
-                        iter = itermax(iter, self.max_minibatches)
                     validation_losses = [validate_model(x,y) for x,y in iter]
                     this_validation_loss = numpy.mean(validation_losses)
 
                     self.series["validation_error"].\
                         append((epoch, minibatch_index), this_validation_loss*100.)
@@ -201,12 +189,10 @@
                         best_validation_loss = this_validation_loss
                         best_iter = total_mb_index
 
                         # test it on the test set
                         iter = dataset.test(minibatch_size)
-                        if self.max_minibatches:
-                            iter = itermax(iter, self.max_minibatches)
                         test_losses = [test_model(x,y) for x,y in iter]
                         test_score = numpy.mean(test_losses)
 
                         self.series["test_error"].\
                             append((epoch, minibatch_index), test_score*100.)
@@ -215,14 +201,10 @@
                                'model %f %%') %
                                  (epoch, minibatch_index+1, self.mb_per_epoch,
                                   test_score*100.))
 
                     sys.stdout.flush()
-
-                # useful when doing tests
-                if self.max_minibatches and minibatch_index >= self.max_minibatches:
-                    break
 
             self.series['params'].append((epoch,), self.classifier.all_params)
 
             if patience <= total_mb_index:
                 done_looping = True
@@ -232,12 +214,23 @@
         self.hp.update({'finetuning_time':end_time-start_time,\
             'best_validation_error':best_validation_loss,\
             'test_score':test_score,
             'num_finetuning_epochs':epoch})
 
+        if self.save_params:
+            save_params(self.classifier.all_params, "weights.dat")
+
         print(('Optimization complete with best validation score of %f %%,'
                'with test performance %f %%') %
                      (best_validation_loss * 100., test_score*100.))
         print ('The finetuning ran for %f minutes' % ((end_time-start_time)/60.))
 
 
 
+def save_params(all_params, filename):
+    import pickle
+    with open(filename, 'wb') as f:
+        values = [p.value for p in all_params]
+
+        # -1 for HIGHEST_PROTOCOL
+        pickle.dump(values, f, -1)
+
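The added save_params() pickles one raw value per parameter, in order, using the highest pickle protocol. No loader is added in this changeset; a minimal counterpart sketch, assuming the same shared-variable .value interface used above:

    # Hypothetical counterpart to save_params(); not part of this changeset.
    def load_params(all_params, filename):
        import pickle
        with open(filename, 'rb') as f:
            values = pickle.load(f)
        # restore values in the same order they were saved
        for p, v in zip(all_params, values):
            p.value = v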