Mercurial > pylearn
comparison learner.py @ 180:2698c0feeb54
mlp seems to work!
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Tue, 13 May 2008 15:35:43 -0400 |
parents | 69759976b3ac |
children | cb6b945acf5a |
comparison
equal
deleted
inserted
replaced
179:9911d2cc3c01 | 180:2698c0feeb54 |
---|---|
2 from exceptions import * | 2 from exceptions import * |
3 from dataset import AttributesHolder,ApplyFunctionDataSet,DataSet,CachedDataSet | 3 from dataset import AttributesHolder,ApplyFunctionDataSet,DataSet,CachedDataSet |
4 import theano | 4 import theano |
5 from theano import compile | 5 from theano import compile |
6 from theano import tensor as t | 6 from theano import tensor as t |
7 | 7 from misc import Print |
| 8 Print = lambda x: lambda y: y |
| 9 |
8 class Learner(AttributesHolder): | 10 class Learner(AttributesHolder): |
9 """ | 11 """ |
10 Base class for learning algorithms, provides an interface | 12 Base class for learning algorithms, provides an interface |
11 that allows various algorithms to be applicable to generic learning | 13 that allows various algorithms to be applicable to generic learning |
12 algorithms. | 14 algorithms. |
186 be used as fields or | 188 be used as fields or |
187 attributes in input/output datasets or in | 189 attributes in input/output datasets or in |
188 stats collectors. All these attributes | 190 stats collectors. All these attributes |
189 are expected to be theano.Result objects | 191 are expected to be theano.Result objects |
190 (with a .data property and recognized by | 192 (with a .data property and recognized by |
191 theano.Function for compilation). The sub-class | 193 theano.function for compilation). The sub-class |
192 constructor defines the relations between the | 194 constructor defines the relations between the |
193 Theano variables that may be used by 'use' | 195 Theano variables that may be used by 'use' |
194 and 'update' or by a stats collector. | 196 and 'update' or by a stats collector. |
195 - defaultOutputFields(input_fields): return a list of default | 197 - defaultOutputFields(input_fields): return a list of default |
196 dataset output fields when | 198 dataset output fields when |
208 | 210 |
209 @todo pousser dans Learner toute la poutine qui peut l'etre sans etre | 211 @todo pousser dans Learner toute la poutine qui peut l'etre sans etre |
210 dependant de Theano | 212 dependant de Theano |
211 """ | 213 """ |
212 | 214 |
213 def __init__(self): | 215 def __init__(self,linker="c|py"): |
214 Learner.__init__(self) | 216 Learner.__init__(self) |
215 self.use_functions_dictionary={} | 217 self.use_functions_dictionary={} |
| 218 self.linker=linker |
216 | 219 |
217 def defaultOutputFields(self, input_fields): | 220 def defaultOutputFields(self, input_fields): |
218 """ | 221 """ |
219 Return a default list of output field names (to put in the output dataset). | 222 Return a default list of output field names (to put in the output dataset). |
220 This will be used when None are provided (as output_fields) by the caller of the 'use' method. | 223 This will be used when None are provided (as output_fields) by the caller of the 'use' method. |
236 key = (tuple(input_fields),tuple(output_fields)) | 239 key = (tuple(input_fields),tuple(output_fields)) |
237 if key not in self.use_functions_dictionary: | 240 if key not in self.use_functions_dictionary: |
238 use_input_attributes = self.useInputAttributes() | 241 use_input_attributes = self.useInputAttributes() |
239 use_output_attributes = self.useOutputAttributes() | 242 use_output_attributes = self.useOutputAttributes() |
240 complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes), | 243 complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes), |
241 self.names2OpResults(output_fields+use_output_attributes)) | 244 self.names2OpResults(output_fields+use_output_attributes), |
| 245 self.linker) |
242 def f(*input_field_values): | 246 def f(*input_field_values): |
243 input_attribute_values = self.names2attributes(use_input_attributes) | 247 input_attribute_values = self.names2attributes(use_input_attributes) |
244 results = complete_f(*(list(input_field_values) + input_attribute_values)) | 248 results = complete_f(*(list(input_field_values) + input_attribute_values)) |
245 output_field_values = results[0:len(output_fields)] | 249 output_field_values = results[0:len(output_fields)] |
246 output_attribute_values = results[len(output_fields):len(results)] | 250 output_attribute_values = results[len(output_fields):len(results)] |
274 (whose names are specified by the user's subclass with the | 278 (whose names are specified by the user's subclass with the |
275 parameterAttributes() method) | 279 parameterAttributes() method) |
276 | 280 |
277 """ | 281 """ |
278 | 282 |
279 def __init__(self): | 283 def __init__(self,linker="c|py"): |
280 TLearner.__init__(self) | 284 TLearner.__init__(self,linker) |
281 self.update_minibatch_function = compile.function(self.names2OpResults(self.updateMinibatchInputAttributes()+ | 285 self.update_minibatch_function = compile.function(self.names2OpResults(self.updateMinibatchInputAttributes()+ |
282 self.updateMinibatchInputFields()), | 286 self.updateMinibatchInputFields()), |
283 self.names2OpResults(self.updateMinibatchOutputAttributes())) | 287 self.names2OpResults(self.updateMinibatchOutputAttributes()), |
| 288 linker) |
284 self.update_end_function = compile.function(self.names2OpResults(self.updateEndInputAttributes()), | 289 self.update_end_function = compile.function(self.names2OpResults(self.updateEndInputAttributes()), |
285 self.names2OpResults(self.updateEndOutputAttributes())) | 290 self.names2OpResults(self.updateEndOutputAttributes()), |
| 291 linker) |
286 | 292 |
287 def allocate(self, minibatch): | 293 def allocate(self, minibatch): |
288 """ | 294 """ |
289 This function is called at the beginning of each L{updateMinibatch} | 295 This function is called at the beginning of each L{updateMinibatch} |
290 and should be used to check that all required attributes have been | 296 and should be used to check that all required attributes have been |
367 | 373 |
368 Sub-classes must define the following: | 374 Sub-classes must define the following: |
369 - self._learning_rate (may be changed by the sub-class between epochs or minibatches) | 375 - self._learning_rate (may be changed by the sub-class between epochs or minibatches) |
370 - self.lossAttribute() = name of the loss field | 376 - self.lossAttribute() = name of the loss field |
371 """ | 377 """ |
372 def __init__(self,truly_online=False): | 378 def __init__(self,truly_online=False,linker="c|py"): |
373 """ | 379 """ |
374 If truly_online then only one pass is made through the training set passed to update(). | 380 If truly_online then only one pass is made through the training set passed to update(). |
375 | 381 |
376 SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS | 382 SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS |
377 """ | 383 """ |
380 # create the formulas for the gradient update | 386 # create the formulas for the gradient update |
381 old_params = [self.__getattribute__("_"+name) for name in self.parameterAttributes()] | 387 old_params = [self.__getattribute__("_"+name) for name in self.parameterAttributes()] |
382 new_params_names = ["_new_"+name for name in self.parameterAttributes()] | 388 new_params_names = ["_new_"+name for name in self.parameterAttributes()] |
383 loss = self.__getattribute__("_"+self.lossAttribute()) | 389 loss = self.__getattribute__("_"+self.lossAttribute()) |
384 self.setAttributes(new_params_names, | 390 self.setAttributes(new_params_names, |
385 [t.add_inplace(param,self._learning_rate*t.grad(loss,param)) | 391 [t.add_inplace(param,-self._learning_rate*Print("grad("+param.name+")")(t.grad(loss,param))) |
386 for param in old_params]) | 392 for param in old_params]) |
387 MinibatchUpdatesTLearner.__init__(self) | 393 MinibatchUpdatesTLearner.__init__(self,linker) |
388 | 394 |
389 | 395 |
390 def namesOfAttributesToComputeOutputs(self,output_names): | 396 def namesOfAttributesToComputeOutputs(self,output_names): |
391 """ | 397 """ |
392 The output_names are attribute names (not the corresponding Result names, which have leading _). | 398 The output_names are attribute names (not the corresponding Result names, which have leading _). |
406 | 412 |
407 def isLastEpoch(self): | 413 def isLastEpoch(self): |
408 return self.truly_online | 414 return self.truly_online |
409 | 415 |
410 def updateMinibatchInputAttributes(self): | 416 def updateMinibatchInputAttributes(self): |
411 return self.parameterAttributes() | 417 return self.parameterAttributes()+["learning_rate"] |
412 | 418 |
413 def updateMinibatchOutputAttributes(self): | 419 def updateMinibatchOutputAttributes(self): |
414 return ["new_"+name for name in self.parameterAttributes()] | 420 return ["new_"+name for name in self.parameterAttributes()] |
415 | 421 |
416 def updateEndInputAttributes(self): | 422 def updateEndInputAttributes(self): |