comparison learner.py @ 119:7ffecde9dadc

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Wed, 07 May 2008 15:08:18 -0400
parents d0a1bd0378c6
children 4efe6d36c061
comparison
equal deleted inserted replaced
116:9330d941fa1f 119:7ffecde9dadc
1 1
2 from dataset import * 2 from dataset import AttributesHolder
3 from compile import Function 3 import compile
4 4
5 class Learner(AttributesHolder): 5 class Learner(AttributesHolder):
6 """Base class for learning algorithms, provides an interface 6 """Base class for learning algorithms, provides an interface
7 that allows various algorithms to be applicable to generic learning 7 that allows various algorithms to be applicable to generic learning
8 algorithms. 8 algorithms.
171 output_fields.append(attribute) 171 output_fields.append(attribute)
172 key = (input_fields,output_fields) 172 key = (input_fields,output_fields)
173 if key not in self.use_functions_dictionary: 173 if key not in self.use_functions_dictionary:
174 use_input_attributes = self.useInputAttributes() 174 use_input_attributes = self.useInputAttributes()
175 use_output_attributes = self.useOutputAttributes() 175 use_output_attributes = self.useOutputAttributes()
176 complete_f = Function(self.names2OpResults(input_fields+use_input_attributes), 176 complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes),
177 self.names2OpResults(output_fields+use_output_attributes)) 177 self.names2OpResults(output_fields+use_output_attributes))
178 def f(*input_field_values): 178 def f(*input_field_values):
179 input_attribute_values = self.names2attributes(use_input_attributes) 179 input_attribute_values = self.names2attributes(use_input_attributes)
180 results = complete_f(*(input_field_values + input_attribute_values)) 180 results = complete_f(*(input_field_values + input_attribute_values))
181 output_field_values = results[0:len(output_fields)] 181 output_field_values = results[0:len(output_fields)]
182 output_attribute_values = results[len(output_fields):len(results)] 182 output_attribute_values = results[len(output_fields):len(results)]
271 271
272 """ 272 """
273 273
274 def __init__(self): 274 def __init__(self):
275 TLearner.__init__(self) 275 TLearner.__init__(self)
276 self.update_minibatch_function = 276 self.update_minibatch_function = compile.function
277 Function(self.names2OpResults(self.updateMinibatchOutputAttributes()+ 277 (self.names2OpResults(self.updateMinibatchOutputAttributes()+
278 self.updateMinibatchInputFields()), 278 self.updateMinibatchInputFields()),
279 self.names2OpResults(self.updateMinibatchOutputAttributes())) 279 self.names2OpResults(self.updateMinibatchOutputAttributes()))
280 self.update_end_function = Function(self.names2OpResults(self.updateEndInputAttributes()), 280 self.update_end_function = compile.function
281 self.names2OpResults(self.updateEndOutputAttributes())) 281 (self.names2OpResults(self.updateEndInputAttributes()),
282 self.names2OpResults(self.updateEndOutputAttributes()))
282 283
283 def updateMinibatchInputFields(self): 284 def updateMinibatchInputFields(self):
284 raise AbstractFunction() 285 raise AbstractFunction()
285 286
286 def updateMinibatchInputAttributes(self): 287 def updateMinibatchInputAttributes(self):
308 309
309 def updateMinibatch(self,minibatch): 310 def updateMinibatch(self,minibatch):
310 # make sure all required fields are allocated and initialized 311 # make sure all required fields are allocated and initialized
311 self.allocate(minibatch) 312 self.allocate(minibatch)
312 self.setAttributes(self.updateMinibatchOutputAttributes(), 313 self.setAttributes(self.updateMinibatchOutputAttributes(),
313 self.update_minibatch_function(*(self.names2attributes(self.updateMinibatchInputAttributes())) 314 # concatenate the attribute values and field values and then apply update fn
315 self.update_minibatch_function(*(self.names2attributes
316 (self.updateMinibatchInputAttributes()))
314 + minibatch(self.updateMinibatchInputFields()))) 317 + minibatch(self.updateMinibatchInputFields())))
315 318
316 def isLastEpoch(self): 319 def isLastEpoch(self):
317 """ 320 """
318 This method is called at the end of each epoch (cycling over the training set). 321 This method is called at the end of each epoch (cycling over the training set).
345 class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner): 348 class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner):
346 """ 349 """
347 Specialization of MinibatchUpdatesTLearner in which the minibatch updates 350 Specialization of MinibatchUpdatesTLearner in which the minibatch updates
348 are obtained by performing an online (minibatch-based) gradient step. 351 are obtained by performing an online (minibatch-based) gradient step.
349 352
350 Sub-classes must define the following methods: 353 Sub-classes must define the following:
351 354
355 self._learning_rate (may be changed by the sub-class between epochs or minibatches)
356
357 self.lossAttribute() = name of the loss field
358
352 """ 359 """
353 def __init__(self,truly_online=False): 360 def __init__(self,truly_online=False):
354 """ 361 """
355 If truly_online then only one pass is made through the training set passed to update(). 362 If truly_online then only one pass is made through the training set passed to update().
356 363
364 SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS
357 """ 365 """
358 self.truly_online=truly_online 366 self.truly_online=truly_online
367
368 # create the formulas for the gradient update
369 old_params = [self.__getattr__("_"+name) for name in self.parameterAttributes()]
370 new_params_names = ["_new_"+name for name in self.parameterAttributes()]
371 loss = self.__getattr__(self.lossAttribute())
372 self.setAttributes(new_params_names,
373 [t.add_inplace(self.param,
374 self._learning_rate*t.grad(loss,param))
375 for param in old_params])
359 376
360 def isLastEpoch(self): 377 def isLastEpoch(self):
361 return self.truly_online 378 return self.truly_online
362 379
363 380 def updateMinibatchInputAttributes(self):
381 return self.parameterAttributes()
382
383 def updateMinibatchOutputAttributes(self):
384 return ["_new"+name for name in self.parameterAttributes()]
385
386 def updateEndInputAttributes(self):
387 return self.parameterAttributes()
388
389