Mercurial > pylearn
comparison learner.py @ 111:88257dfedf8c
Added another work in progress, for MLPs
author | bengioy@bengiomac.local |
---|---|
date | Wed, 07 May 2008 09:16:04 -0400 |
parents | 8fa1ef2411a0 |
children | d0a1bd0378c6 |
comparison
equal
deleted
inserted
replaced
110:8fa1ef2411a0 | 111:88257dfedf8c |
---|---|
1 | 1 |
2 from dataset import * | 2 from dataset import * |
3 from compile import Function | |
3 | 4 |
4 class Learner(AttributesHolder): | 5 class Learner(AttributesHolder): |
5 """Base class for learning algorithms, provides an interface | 6 """Base class for learning algorithms, provides an interface |
6 that allows various algorithms to be applicable to generic learning | 7 that allows various algorithms to be applicable to generic learning |
7 algorithms. | 8 algorithms. |
82 | 83 |
83 def updateOutputAttributes(self): | 84 def updateOutputAttributes(self): |
84 """ | 85 """ |
85 A subset of self.attributeNames() which are the names of attributes modified/created by update() in order | 86 A subset of self.attributeNames() which are the names of attributes modified/created by update() in order |
86 to do its work. | 87 to do its work. |
87 """ | 88 |
88 raise AbstractFunction() | 89 By default these are inferred from the various update output attributes: |
90 """ | |
91 return ["parameters"] + self.updateMinibatchOutputAttributes() + self.updateEndOutputAttributes() | |
89 | 92 |
90 def useOutputAttributes(self): | 93 def useOutputAttributes(self): |
91 """ | 94 """ |
92 A subset of self.attributeNames() which are the names of attributes modified/created by use() in order | 95 A subset of self.attributeNames() which are the names of attributes modified/created by use() in order |
93 to do its work. | 96 to do its work. |
249 output_dataset.setAttributes(test_stats_collector.attributeNames(), | 252 output_dataset.setAttributes(test_stats_collector.attributeNames(), |
250 test_stats_collector.attributes()) | 253 test_stats_collector.attributes()) |
251 return output_dataset | 254 return output_dataset |
252 | 255 |
253 | 256 |
254 class OneShotTLearner(TLearner): | 257 class MinibatchUpdatesTLearner(TLearner): |
255 """ | 258 """ |
256 This adds to TLearner a | 259 This adds to TLearner a |
257 - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch(): | 260 - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch(): |
258 functions executed at the beginning, the end, in the middle | 261 functions executed at the beginning, the end, in the middle |
259 (for each minibatch) of the update method, and at the end | 262 (for each minibatch) of the update method, and at the end |
260 of each epoch. This model only | 263 of each epoch. This model only |
261 works for 'online' or one-shot learning that requires | 264 works for 'online' or one-shot learning that requires |
262 going only once through the training data. For more complicated | 265 going only once through the training data. For more complicated |
263 models, more specialized subclasses of TLearner should be used | 266 models, more specialized subclasses of TLearner should be used |
264 or a learning-algorithm specific update method should be defined. | 267 or a learning-algorithm specific update method should be defined. |
268 | |
269 - a 'parameters' attribute which is a list of parameters (whose names are | |
270 specified by the user's subclass with the parameterAttributes() method) | |
271 | |
265 """ | 272 """ |
266 | 273 |
267 def __init__(self): | 274 def __init__(self): |
268 TLearner.__init__(self) | 275 TLearner.__init__(self) |
269 self.update_minibatch_function = | 276 self.update_minibatch_function = |
286 raise AbstractFunction() | 293 raise AbstractFunction() |
287 | 294 |
288 def updateEndOutputAttributes(self): | 295 def updateEndOutputAttributes(self): |
289 raise AbstractFunction() | 296 raise AbstractFunction() |
290 | 297 |
291 def updateStart(self): pass | 298 def parameterAttributes(self): |
299 raise AbstractFunction() | |
300 | |
301 def updateStart(self): pass | |
292 | 302 |
293 def updateEnd(self): | 303 def updateEnd(self): |
294 self.setAttributes(self.updateEndOutputAttributes(), | 304 self.setAttributes(self.updateEndOutputAttributes(), |
295 self.update_end_function | 305 self.update_end_function |
296 (self.names2attributes(self.updateEndInputAttributes()))) | 306 (self.names2attributes(self.updateEndInputAttributes()))) |
307 self.parameters = self.names2attributes(self.parameterAttributes()) | |
297 | 308 |
298 def updateMinibatch(self,minibatch): | 309 def updateMinibatch(self,minibatch): |
299 # make sure all required fields are allocated and initialized | 310 # make sure all required fields are allocated and initialized |
300 self.allocate(minibatch) | 311 self.allocate(minibatch) |
301 self.setAttributes(self.updateMinibatchOutputAttributes(), | 312 self.setAttributes(self.updateMinibatchOutputAttributes(), |
329 train_stats_collector.update(minibatch_set) | 340 train_stats_collector.update(minibatch_set) |
330 stop = self.isLastEpoch() | 341 stop = self.isLastEpoch() |
331 self.updateEnd() | 342 self.updateEnd() |
332 return self.use | 343 return self.use |
333 | 344 |
345 class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner): | |
346 """ | |
347 Specialization of MinibatchUpdatesTLearner in which the minibatch updates | |
348 are obtained by performing an online (minibatch-based) gradient step. | |
349 | |
350 Sub-classes must define the following methods: | |
351 | |
352 """ | |
353 def __init__(self,truly_online=False): | |
354 """ | |
355 If truly_online then only one pass is made through the training set passed to update(). | |
356 | |
357 """ | |
358 self.truly_online=truly_online | |
359 | |
360 def isLastEpoch(self): | |
361 return self.truly_online | |
362 | |
363 |