comparison learner.py @ 131:57e6492644ec

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Wed, 07 May 2008 21:40:15 -0400
parents 3d8e40e7ed18 4c2280edcaf5
children f6505ec32dc3
comparison
equal deleted inserted replaced
130:3d8e40e7ed18 131:57e6492644ec
1 1
2 from dataset import AttributesHolder,AbstractFunction 2 from dataset import AttributesHolder,AbstractFunction
3 import compile 3 import compile
4 from theano import tensor as t
4 5
5 class Learner(AttributesHolder): 6 class Learner(AttributesHolder):
6 """Base class for learning algorithms, provides an interface 7 """Base class for learning algorithms, provides an interface
7 that allows various algorithms to be applicable to generic learning 8 that allows various algorithms to be applicable to generic learning
8 algorithms. 9 algorithms.
134 """ 135 """
135 Private helper function that maps a list of attribute names to a list 136 Private helper function that maps a list of attribute names to a list
136 of (optionally copies) values of attributes. 137 of (optionally copies) values of attributes.
137 """ 138 """
138 if return_copy: 139 if return_copy:
139 return [copy.deepcopy(self.__getattr__(name).data) for name in names] 140 return [copy.deepcopy(self.__getattribute__(name).data) for name in names]
140 else: 141 else:
141 return [self.__getattr__(name).data for name in names] 142 return [self.__getattribute__(name).data for name in names]
142 143
143 def updateInputAttributes(self): 144 def updateInputAttributes(self):
144 """ 145 """
145 A subset of self.attributeNames() which are the names of attributes needed by update() in order 146 A subset of self.attributeNames() which are the names of attributes needed by update() in order
146 to do its work. 147 to do its work.
250 def names2OpResults(self,names): 251 def names2OpResults(self,names):
251 """ 252 """
252 Private helper function that maps a list of attribute names to a list 253 Private helper function that maps a list of attribute names to a list
253 of corresponding Op Results (with the same name but with a '_' prefix). 254 of corresponding Op Results (with the same name but with a '_' prefix).
254 """ 255 """
255 return [self.__getattr__('_'+name).data for name in names] 256 return [self.__getattribute__('_'+name).data for name in names]
256 257
257 258
258 class MinibatchUpdatesTLearner(TLearner): 259 class MinibatchUpdatesTLearner(TLearner):
259 """ 260 """
260 This adds to TLearner a 261 This adds to TLearner a
353 train_stats_collector.update(minibatch_set) 354 train_stats_collector.update(minibatch_set)
354 stop = self.isLastEpoch() 355 stop = self.isLastEpoch()
355 self.updateEnd() 356 self.updateEnd()
356 return self.use 357 return self.use
357 358
358 class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner): 359 class OnlineGradientTLearner(MinibatchUpdatesTLearner):
359 """ 360 """
360 Specialization of MinibatchUpdatesTLearner in which the minibatch updates 361 Specialization of MinibatchUpdatesTLearner in which the minibatch updates
361 are obtained by performing an online (minibatch-based) gradient step. 362 are obtained by performing an online (minibatch-based) gradient step.
362 363
363 Sub-classes must define the following: 364 Sub-classes must define the following:
374 SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS 375 SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS
375 """ 376 """
376 self.truly_online=truly_online 377 self.truly_online=truly_online
377 378
378 # create the formulas for the gradient update 379 # create the formulas for the gradient update
379 old_params = [self.__getattr__("_"+name) for name in self.parameterAttributes()] 380 old_params = [self.__getattribute__("_"+name) for name in self.parameterAttributes()]
380 new_params_names = ["_new_"+name for name in self.parameterAttributes()] 381 new_params_names = ["_new_"+name for name in self.parameterAttributes()]
381 loss = self.__getattr__(self.lossAttribute()) 382 loss = self.__getattribute__("_"+self.lossAttribute())
382 self.setAttributes(new_params_names, 383 self.setAttributes(new_params_names,
383 [t.add_inplace(self.param, 384 [t.add_inplace(param,self._learning_rate*t.grad(loss,param))
384 self._learning_rate*t.grad(loss,param))
385 for param in old_params]) 385 for param in old_params])
386 386 MinibatchUpdatesTLearner.__init__(self)
387
387 def isLastEpoch(self): 388 def isLastEpoch(self):
388 return self.truly_online 389 return self.truly_online
389 390
390 def updateMinibatchInputAttributes(self): 391 def updateMinibatchInputAttributes(self):
391 return self.parameterAttributes() 392 return self.parameterAttributes()