Mercurial > pylearn
comparison learner.py @ 131:57e6492644ec
Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Wed, 07 May 2008 21:40:15 -0400 |
parents | 3d8e40e7ed18 4c2280edcaf5 |
children | f6505ec32dc3 |
comparison legend: equal | deleted | inserted | replaced
130:3d8e40e7ed18 | 131:57e6492644ec |
---|---|
1 | 1 |
2 from dataset import AttributesHolder,AbstractFunction | 2 from dataset import AttributesHolder,AbstractFunction |
3 import compile | 3 import compile |
4 from theano import tensor as t | |
4 | 5 |
5 class Learner(AttributesHolder): | 6 class Learner(AttributesHolder): |
6 """Base class for learning algorithms, provides an interface | 7 """Base class for learning algorithms, provides an interface |
7 that allows various algorithms to be applicable to generic learning | 8 that allows various algorithms to be applicable to generic learning |
8 algorithms. | 9 algorithms. |
134 """ | 135 """ |
135 Private helper function that maps a list of attribute names to a list | 136 Private helper function that maps a list of attribute names to a list |
136 of (optionally copies) values of attributes. | 137 of (optionally copies) values of attributes. |
137 """ | 138 """ |
138 if return_copy: | 139 if return_copy: |
139 return [copy.deepcopy(self.__getattr__(name).data) for name in names] | 140 return [copy.deepcopy(self.__getattribute__(name).data) for name in names] |
140 else: | 141 else: |
141 return [self.__getattr__(name).data for name in names] | 142 return [self.__getattribute__(name).data for name in names] |
142 | 143 |
143 def updateInputAttributes(self): | 144 def updateInputAttributes(self): |
144 """ | 145 """ |
145 A subset of self.attributeNames() which are the names of attributes needed by update() in order | 146 A subset of self.attributeNames() which are the names of attributes needed by update() in order |
146 to do its work. | 147 to do its work. |
250 def names2OpResults(self,names): | 251 def names2OpResults(self,names): |
251 """ | 252 """ |
252 Private helper function that maps a list of attribute names to a list | 253 Private helper function that maps a list of attribute names to a list |
253 of corresponding Op Results (with the same name but with a '_' prefix). | 254 of corresponding Op Results (with the same name but with a '_' prefix). |
254 """ | 255 """ |
255 return [self.__getattr__('_'+name).data for name in names] | 256 return [self.__getattribute__('_'+name).data for name in names] |
256 | 257 |
257 | 258 |
258 class MinibatchUpdatesTLearner(TLearner): | 259 class MinibatchUpdatesTLearner(TLearner): |
259 """ | 260 """ |
260 This adds to TLearner a | 261 This adds to TLearner a |
353 train_stats_collector.update(minibatch_set) | 354 train_stats_collector.update(minibatch_set) |
354 stop = self.isLastEpoch() | 355 stop = self.isLastEpoch() |
355 self.updateEnd() | 356 self.updateEnd() |
356 return self.use | 357 return self.use |
357 | 358 |
358 class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner): | 359 class OnlineGradientTLearner(MinibatchUpdatesTLearner): |
359 """ | 360 """ |
360 Specialization of MinibatchUpdatesTLearner in which the minibatch updates | 361 Specialization of MinibatchUpdatesTLearner in which the minibatch updates |
361 are obtained by performing an online (minibatch-based) gradient step. | 362 are obtained by performing an online (minibatch-based) gradient step. |
362 | 363 |
363 Sub-classes must define the following: | 364 Sub-classes must define the following: |
374 SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS | 375 SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS |
375 """ | 376 """ |
376 self.truly_online=truly_online | 377 self.truly_online=truly_online |
377 | 378 |
378 # create the formulas for the gradient update | 379 # create the formulas for the gradient update |
379 old_params = [self.__getattr__("_"+name) for name in self.parameterAttributes()] | 380 old_params = [self.__getattribute__("_"+name) for name in self.parameterAttributes()] |
380 new_params_names = ["_new_"+name for name in self.parameterAttributes()] | 381 new_params_names = ["_new_"+name for name in self.parameterAttributes()] |
381 loss = self.__getattr__(self.lossAttribute()) | 382 loss = self.__getattribute__("_"+self.lossAttribute()) |
382 self.setAttributes(new_params_names, | 383 self.setAttributes(new_params_names, |
383 [t.add_inplace(self.param, | 384 [t.add_inplace(param,self._learning_rate*t.grad(loss,param)) |
384 self._learning_rate*t.grad(loss,param)) | |
385 for param in old_params]) | 385 for param in old_params]) |
386 | 386 MinibatchUpdatesTLearner.__init__(self) |
387 | |
387 def isLastEpoch(self): | 388 def isLastEpoch(self): |
388 return self.truly_online | 389 return self.truly_online |
389 | 390 |
390 def updateMinibatchInputAttributes(self): | 391 def updateMinibatchInputAttributes(self): |
391 return self.parameterAttributes() | 392 return self.parameterAttributes() |