comparison learner.py @ 180:2698c0feeb54

mlp seems to work!
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Tue, 13 May 2008 15:35:43 -0400
parents 69759976b3ac
children cb6b945acf5a
comparison
equal deleted inserted replaced
179:9911d2cc3c01 180:2698c0feeb54
2 from exceptions import * 2 from exceptions import *
3 from dataset import AttributesHolder,ApplyFunctionDataSet,DataSet,CachedDataSet 3 from dataset import AttributesHolder,ApplyFunctionDataSet,DataSet,CachedDataSet
4 import theano 4 import theano
5 from theano import compile 5 from theano import compile
6 from theano import tensor as t 6 from theano import tensor as t
7 7 from misc import Print
8 Print = lambda x: lambda y: y
9
8 class Learner(AttributesHolder): 10 class Learner(AttributesHolder):
9 """ 11 """
10 Base class for learning algorithms, provides an interface 12 Base class for learning algorithms, provides an interface
11 that allows various algorithms to be applicable to generic learning 13 that allows various algorithms to be applicable to generic learning
12 algorithms. 14 algorithms.
186 be used as fields or 188 be used as fields or
187 attributes in input/output datasets or in 189 attributes in input/output datasets or in
188 stats collectors. All these attributes 190 stats collectors. All these attributes
189 are expected to be theano.Result objects 191 are expected to be theano.Result objects
190 (with a .data property and recognized by 192 (with a .data property and recognized by
191 theano.Function for compilation). The sub-class 193 theano.function for compilation). The sub-class
192 constructor defines the relations between the 194 constructor defines the relations between the
193 Theano variables that may be used by 'use' 195 Theano variables that may be used by 'use'
194 and 'update' or by a stats collector. 196 and 'update' or by a stats collector.
195 - defaultOutputFields(input_fields): return a list of default 197 - defaultOutputFields(input_fields): return a list of default
196 dataset output fields when 198 dataset output fields when
208 210
209 @todo pousser dans Learner toute la poutine qui peut l'etre sans etre 211 @todo pousser dans Learner toute la poutine qui peut l'etre sans etre
210 dependant de Theano 212 dependant de Theano
211 """ 213 """
212 214
213 def __init__(self): 215 def __init__(self,linker="c|py"):
214 Learner.__init__(self) 216 Learner.__init__(self)
215 self.use_functions_dictionary={} 217 self.use_functions_dictionary={}
218 self.linker=linker
216 219
217 def defaultOutputFields(self, input_fields): 220 def defaultOutputFields(self, input_fields):
218 """ 221 """
219 Return a default list of output field names (to put in the output dataset). 222 Return a default list of output field names (to put in the output dataset).
220 This will be used when None are provided (as output_fields) by the caller of the 'use' method. 223 This will be used when None are provided (as output_fields) by the caller of the 'use' method.
236 key = (tuple(input_fields),tuple(output_fields)) 239 key = (tuple(input_fields),tuple(output_fields))
237 if key not in self.use_functions_dictionary: 240 if key not in self.use_functions_dictionary:
238 use_input_attributes = self.useInputAttributes() 241 use_input_attributes = self.useInputAttributes()
239 use_output_attributes = self.useOutputAttributes() 242 use_output_attributes = self.useOutputAttributes()
240 complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes), 243 complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes),
241 self.names2OpResults(output_fields+use_output_attributes)) 244 self.names2OpResults(output_fields+use_output_attributes),
245 self.linker)
242 def f(*input_field_values): 246 def f(*input_field_values):
243 input_attribute_values = self.names2attributes(use_input_attributes) 247 input_attribute_values = self.names2attributes(use_input_attributes)
244 results = complete_f(*(list(input_field_values) + input_attribute_values)) 248 results = complete_f(*(list(input_field_values) + input_attribute_values))
245 output_field_values = results[0:len(output_fields)] 249 output_field_values = results[0:len(output_fields)]
246 output_attribute_values = results[len(output_fields):len(results)] 250 output_attribute_values = results[len(output_fields):len(results)]
274 (whose names are specified by the user's subclass with the 278 (whose names are specified by the user's subclass with the
275 parameterAttributes() method) 279 parameterAttributes() method)
276 280
277 """ 281 """
278 282
279 def __init__(self): 283 def __init__(self,linker="c|py"):
280 TLearner.__init__(self) 284 TLearner.__init__(self,linker)
281 self.update_minibatch_function = compile.function(self.names2OpResults(self.updateMinibatchInputAttributes()+ 285 self.update_minibatch_function = compile.function(self.names2OpResults(self.updateMinibatchInputAttributes()+
282 self.updateMinibatchInputFields()), 286 self.updateMinibatchInputFields()),
283 self.names2OpResults(self.updateMinibatchOutputAttributes())) 287 self.names2OpResults(self.updateMinibatchOutputAttributes()),
288 linker)
284 self.update_end_function = compile.function(self.names2OpResults(self.updateEndInputAttributes()), 289 self.update_end_function = compile.function(self.names2OpResults(self.updateEndInputAttributes()),
285 self.names2OpResults(self.updateEndOutputAttributes())) 290 self.names2OpResults(self.updateEndOutputAttributes()),
291 linker)
286 292
287 def allocate(self, minibatch): 293 def allocate(self, minibatch):
288 """ 294 """
289 This function is called at the beginning of each L{updateMinibatch} 295 This function is called at the beginning of each L{updateMinibatch}
290 and should be used to check that all required attributes have been 296 and should be used to check that all required attributes have been
367 373
368 Sub-classes must define the following: 374 Sub-classes must define the following:
369 - self._learning_rate (may be changed by the sub-class between epochs or minibatches) 375 - self._learning_rate (may be changed by the sub-class between epochs or minibatches)
370 - self.lossAttribute() = name of the loss field 376 - self.lossAttribute() = name of the loss field
371 """ 377 """
372 def __init__(self,truly_online=False): 378 def __init__(self,truly_online=False,linker="c|py"):
373 """ 379 """
374 If truly_online then only one pass is made through the training set passed to update(). 380 If truly_online then only one pass is made through the training set passed to update().
375 381
376 SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS 382 SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS
377 """ 383 """
380 # create the formulas for the gradient update 386 # create the formulas for the gradient update
381 old_params = [self.__getattribute__("_"+name) for name in self.parameterAttributes()] 387 old_params = [self.__getattribute__("_"+name) for name in self.parameterAttributes()]
382 new_params_names = ["_new_"+name for name in self.parameterAttributes()] 388 new_params_names = ["_new_"+name for name in self.parameterAttributes()]
383 loss = self.__getattribute__("_"+self.lossAttribute()) 389 loss = self.__getattribute__("_"+self.lossAttribute())
384 self.setAttributes(new_params_names, 390 self.setAttributes(new_params_names,
385 [t.add_inplace(param,self._learning_rate*t.grad(loss,param)) 391 [t.add_inplace(param,-self._learning_rate*Print("grad("+param.name+")")(t.grad(loss,param)))
386 for param in old_params]) 392 for param in old_params])
387 MinibatchUpdatesTLearner.__init__(self) 393 MinibatchUpdatesTLearner.__init__(self,linker)
388 394
389 395
390 def namesOfAttributesToComputeOutputs(self,output_names): 396 def namesOfAttributesToComputeOutputs(self,output_names):
391 """ 397 """
392 The output_names are attribute names (not the corresponding Result names, which have leading _). 398 The output_names are attribute names (not the corresponding Result names, which have leading _).
406 412
407 def isLastEpoch(self): 413 def isLastEpoch(self):
408 return self.truly_online 414 return self.truly_online
409 415
410 def updateMinibatchInputAttributes(self): 416 def updateMinibatchInputAttributes(self):
411 return self.parameterAttributes() 417 return self.parameterAttributes()+["learning_rate"]
412 418
413 def updateMinibatchOutputAttributes(self): 419 def updateMinibatchOutputAttributes(self):
414 return ["new_"+name for name in self.parameterAttributes()] 420 return ["new_"+name for name in self.parameterAttributes()]
415 421
416 def updateEndInputAttributes(self): 422 def updateEndInputAttributes(self):