comparison learner.py @ 111:88257dfedf8c

Added another work in progress, for MLPs
author bengioy@bengiomac.local
date Wed, 07 May 2008 09:16:04 -0400
parents 8fa1ef2411a0
children d0a1bd0378c6
comparison
equal deleted inserted replaced
110:8fa1ef2411a0 111:88257dfedf8c
1 1
2 from dataset import * 2 from dataset import *
3 from compile import Function
3 4
4 class Learner(AttributesHolder): 5 class Learner(AttributesHolder):
5 """Base class for learning algorithms, provides an interface 6 """Base class for learning algorithms, provides an interface
6 that allows various algorithms to be applicable to generic learning 7 that allows various algorithms to be applicable to generic learning
7 algorithms. 8 algorithms.
82 83
83 def updateOutputAttributes(self): 84 def updateOutputAttributes(self):
84 """ 85 """
85 A subset of self.attributeNames() which are the names of attributes modified/created by update() in order 86 A subset of self.attributeNames() which are the names of attributes modified/created by update() in order
86 to do its work. 87 to do its work.
87 """ 88
88 raise AbstractFunction() 89 By default these are inferred from the various update output attributes:
90 """
91 return ["parameters"] + self.updateMinibatchOutputAttributes() + self.updateEndOutputAttributes()
89 92
90 def useOutputAttributes(self): 93 def useOutputAttributes(self):
91 """ 94 """
92 A subset of self.attributeNames() which are the names of attributes modified/created by use() in order 95 A subset of self.attributeNames() which are the names of attributes modified/created by use() in order
93 to do its work. 96 to do its work.
249 output_dataset.setAttributes(test_stats_collector.attributeNames(), 252 output_dataset.setAttributes(test_stats_collector.attributeNames(),
250 test_stats_collector.attributes()) 253 test_stats_collector.attributes())
251 return output_dataset 254 return output_dataset
252 255
253 256
254 class OneShotTLearner(TLearner): 257 class MinibatchUpdatesTLearner(TLearner):
255 """ 258 """
256 This adds to TLearner a 259 This adds to TLearner a
257 - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch(): 260 - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch():
258 functions executed at the beginning, the end, in the middle 261 functions executed at the beginning, the end, in the middle
259 (for each minibatch) of the update method, and at the end 262 (for each minibatch) of the update method, and at the end
260 of each epoch. This model only 263 of each epoch. This model only
261 works for 'online' or one-shot learning that requires 264 works for 'online' or one-shot learning that requires
262 going only once through the training data. For more complicated 265 going only once through the training data. For more complicated
263 models, more specialized subclasses of TLearner should be used 266 models, more specialized subclasses of TLearner should be used
264 or a learning-algorithm specific update method should be defined. 267 or a learning-algorithm specific update method should be defined.
268
269 - a 'parameters' attribute which is a list of parameters (whose names are
270 specified by the user's subclass with the parameterAttributes() method)
271
265 """ 272 """
266 273
267 def __init__(self): 274 def __init__(self):
268 TLearner.__init__(self) 275 TLearner.__init__(self)
269 self.update_minibatch_function = 276 self.update_minibatch_function =
286 raise AbstractFunction() 293 raise AbstractFunction()
287 294
288 def updateEndOutputAttributes(self): 295 def updateEndOutputAttributes(self):
289 raise AbstractFunction() 296 raise AbstractFunction()
290 297
291 def updateStart(self): pass 298 def parameterAttributes(self):
299 raise AbstractFunction()
300
301 def updateStart(self): pass
292 302
293 def updateEnd(self): 303 def updateEnd(self):
294 self.setAttributes(self.updateEndOutputAttributes(), 304 self.setAttributes(self.updateEndOutputAttributes(),
295 self.update_end_function 305 self.update_end_function
296 (self.names2attributes(self.updateEndInputAttributes()))) 306 (self.names2attributes(self.updateEndInputAttributes())))
307 self.parameters = self.names2attributes(self.parameterAttributes())
297 308
298 def updateMinibatch(self,minibatch): 309 def updateMinibatch(self,minibatch):
299 # make sure all required fields are allocated and initialized 310 # make sure all required fields are allocated and initialized
300 self.allocate(minibatch) 311 self.allocate(minibatch)
301 self.setAttributes(self.updateMinibatchOutputAttributes(), 312 self.setAttributes(self.updateMinibatchOutputAttributes(),
329 train_stats_collector.update(minibatch_set) 340 train_stats_collector.update(minibatch_set)
330 stop = self.isLastEpoch() 341 stop = self.isLastEpoch()
331 self.updateEnd() 342 self.updateEnd()
332 return self.use 343 return self.use
333 344
345 class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner):
346 """
347 Specialization of MinibatchUpdatesTLearner in which the minibatch updates
348 are obtained by performing an online (minibatch-based) gradient step.
349
350 Sub-classes must define the following methods:
351
352 """
353 def __init__(self,truly_online=False):
354 """
355 If truly_online then only one pass is made through the training set passed to update().
356
357 """
358 self.truly_online=truly_online
359
360 def isLastEpoch(self):
361 return self.truly_online
362
363