Mercurial > pylearn
comparison learner.py @ 119:7ffecde9dadc
Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Wed, 07 May 2008 15:08:18 -0400 |
parents | d0a1bd0378c6 |
children | 4efe6d36c061 |
comparison
equal
deleted
inserted
replaced
116:9330d941fa1f | 119:7ffecde9dadc |
---|---|
1 | 1 |
2 from dataset import * | 2 from dataset import AttributesHolder |
3 from compile import Function | 3 import compile |
4 | 4 |
5 class Learner(AttributesHolder): | 5 class Learner(AttributesHolder): |
6 """Base class for learning algorithms, provides an interface | 6 """Base class for learning algorithms, provides an interface |
7 that allows various algorithms to be applicable to generic learning | 7 that allows various algorithms to be applicable to generic learning |
8 algorithms. | 8 algorithms. |
171 output_fields.append(attribute) | 171 output_fields.append(attribute) |
172 key = (input_fields,output_fields) | 172 key = (input_fields,output_fields) |
173 if key not in self.use_functions_dictionary: | 173 if key not in self.use_functions_dictionary: |
174 use_input_attributes = self.useInputAttributes() | 174 use_input_attributes = self.useInputAttributes() |
175 use_output_attributes = self.useOutputAttributes() | 175 use_output_attributes = self.useOutputAttributes() |
176 complete_f = Function(self.names2OpResults(input_fields+use_input_attributes), | 176 complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes), |
177 self.names2OpResults(output_fields+use_output_attributes)) | 177 self.names2OpResults(output_fields+use_output_attributes)) |
178 def f(*input_field_values): | 178 def f(*input_field_values): |
179 input_attribute_values = self.names2attributes(use_input_attributes) | 179 input_attribute_values = self.names2attributes(use_input_attributes) |
180 results = complete_f(*(input_field_values + input_attribute_values)) | 180 results = complete_f(*(input_field_values + input_attribute_values)) |
181 output_field_values = results[0:len(output_fields)] | 181 output_field_values = results[0:len(output_fields)] |
182 output_attribute_values = results[len(output_fields):len(results)] | 182 output_attribute_values = results[len(output_fields):len(results)] |
271 | 271 |
272 """ | 272 """ |
273 | 273 |
274 def __init__(self): | 274 def __init__(self): |
275 TLearner.__init__(self) | 275 TLearner.__init__(self) |
276 self.update_minibatch_function = | 276 self.update_minibatch_function = compile.function |
277 Function(self.names2OpResults(self.updateMinibatchOutputAttributes()+ | 277 (self.names2OpResults(self.updateMinibatchOutputAttributes()+ |
278 self.updateMinibatchInputFields()), | 278 self.updateMinibatchInputFields()), |
279 self.names2OpResults(self.updateMinibatchOutputAttributes())) | 279 self.names2OpResults(self.updateMinibatchOutputAttributes())) |
280 self.update_end_function = Function(self.names2OpResults(self.updateEndInputAttributes()), | 280 self.update_end_function = compile.function |
281 self.names2OpResults(self.updateEndOutputAttributes())) | 281 (self.names2OpResults(self.updateEndInputAttributes()), |
282 self.names2OpResults(self.updateEndOutputAttributes())) | |
282 | 283 |
283 def updateMinibatchInputFields(self): | 284 def updateMinibatchInputFields(self): |
284 raise AbstractFunction() | 285 raise AbstractFunction() |
285 | 286 |
286 def updateMinibatchInputAttributes(self): | 287 def updateMinibatchInputAttributes(self): |
308 | 309 |
309 def updateMinibatch(self,minibatch): | 310 def updateMinibatch(self,minibatch): |
310 # make sure all required fields are allocated and initialized | 311 # make sure all required fields are allocated and initialized |
311 self.allocate(minibatch) | 312 self.allocate(minibatch) |
312 self.setAttributes(self.updateMinibatchOutputAttributes(), | 313 self.setAttributes(self.updateMinibatchOutputAttributes(), |
313 self.update_minibatch_function(*(self.names2attributes(self.updateMinibatchInputAttributes())) | 314 # concatenate the attribute values and field values and then apply update fn |
315 self.update_minibatch_function(*(self.names2attributes | |
316 (self.updateMinibatchInputAttributes())) | |
314 + minibatch(self.updateMinibatchInputFields()))) | 317 + minibatch(self.updateMinibatchInputFields()))) |
315 | 318 |
316 def isLastEpoch(self): | 319 def isLastEpoch(self): |
317 """ | 320 """ |
318 This method is called at the end of each epoch (cycling over the training set). | 321 This method is called at the end of each epoch (cycling over the training set). |
345 class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner): | 348 class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner): |
346 """ | 349 """ |
347 Specialization of MinibatchUpdatesTLearner in which the minibatch updates | 350 Specialization of MinibatchUpdatesTLearner in which the minibatch updates |
348 are obtained by performing an online (minibatch-based) gradient step. | 351 are obtained by performing an online (minibatch-based) gradient step. |
349 | 352 |
350 Sub-classes must define the following methods: | 353 Sub-classes must define the following: |
351 | 354 |
355 self._learning_rate (may be changed by the sub-class between epochs or minibatches) | |
356 | |
357 self.lossAttribute() = name of the loss field | |
358 | |
352 """ | 359 """ |
353 def __init__(self,truly_online=False): | 360 def __init__(self,truly_online=False): |
354 """ | 361 """ |
355 If truly_online then only one pass is made through the training set passed to update(). | 362 If truly_online then only one pass is made through the training set passed to update(). |
356 | 363 |
364 SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS | |
357 """ | 365 """ |
358 self.truly_online=truly_online | 366 self.truly_online=truly_online |
367 | |
368 # create the formulas for the gradient update | |
369 old_params = [self.__getattr__("_"+name) for name in self.parameterAttributes()] | |
370 new_params_names = ["_new_"+name for name in self.parameterAttributes()] | |
371 loss = self.__getattr__(self.lossAttribute()) | |
372 self.setAttributes(new_params_names, | |
373 [t.add_inplace(self.param, | |
374 self._learning_rate*t.grad(loss,param)) | |
375 for param in old_params]) | |
359 | 376 |
360 def isLastEpoch(self): | 377 def isLastEpoch(self): |
361 return self.truly_online | 378 return self.truly_online |
362 | 379 |
363 | 380 def updateMinibatchInputAttributes(self): |
381 return self.parameterAttributes() | |
382 | |
383 def updateMinibatchOutputAttributes(self): | |
384 return ["_new"+name for name in self.parameterAttributes()] | |
385 | |
386 def updateEndInputAttributes(self): | |
387 return self.parameterAttributes() | |
388 | |
389 |