Mercurial > pylearn
comparison learner.py @ 133:b4657441dd65
Corrected typos
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Fri, 09 May 2008 13:38:54 -0400 |
parents | f6505ec32dc3 |
children | 3f4e5c9bdc5e |
comparison
equal
deleted
inserted
replaced
132:f6505ec32dc3 | 133:b4657441dd65 |
---|---|
45 """ | 45 """ |
46 Train a learner from scratch using the provided training set, | 46 Train a learner from scratch using the provided training set, |
47 and return the learned function. | 47 and return the learned function. |
48 """ | 48 """ |
49 self.forget() | 49 self.forget() |
50 return self.update(learning_task,train_stats_collector) | 50 return self.update(training_set,train_stats_collector) |
51 | 51 |
52 def use(self,input_dataset,output_fieldnames=None, | 52 def use(self,input_dataset,output_fieldnames=None, |
53 test_stats_collector=None,copy_inputs=True, | 53 test_stats_collector=None,copy_inputs=True, |
54 put_stats_in_output_dataset=True, | 54 put_stats_in_output_dataset=True, |
55 output_attributes=[]): | 55 output_attributes=[]): |
252 def names2OpResults(self,names): | 252 def names2OpResults(self,names): |
253 """ | 253 """ |
254 Private helper function that maps a list of attribute names to a list | 254 Private helper function that maps a list of attribute names to a list |
255 of corresponding Op Results (with the same name but with a '_' prefix). | 255 of corresponding Op Results (with the same name but with a '_' prefix). |
256 """ | 256 """ |
257 return [self.__getattribute__('_'+name).data for name in names] | 257 return [self.__getattribute__('_'+name) for name in names] |
258 | 258 |
259 | 259 |
260 class MinibatchUpdatesTLearner(TLearner): | 260 class MinibatchUpdatesTLearner(TLearner): |
261 """ | 261 """ |
262 This adds to L{TLearner} a | 262 This adds to L{TLearner} a |
309 raise AbstractFunction() | 309 raise AbstractFunction() |
310 | 310 |
311 def parameterAttributes(self): | 311 def parameterAttributes(self): |
312 raise AbstractFunction() | 312 raise AbstractFunction() |
313 | 313 |
314 def updateStart(self): pass | 314 def updateStart(self,training_set): |
315 pass | |
315 | 316 |
316 def updateEnd(self): | 317 def updateEnd(self): |
317 self.setAttributes(self.updateEndOutputAttributes(), | 318 self.setAttributes(self.updateEndOutputAttributes(), |
318 self.update_end_function | 319 self.update_end_function |
319 (self.names2attributes(self.updateEndInputAttributes()))) | 320 (self.names2attributes(self.updateEndInputAttributes()))) |
341 @todo check if some of the learner attributes are actually SPECIFIED | 342 @todo check if some of the learner attributes are actually SPECIFIED |
342 in as attributes of the training_set. | 343 in as attributes of the training_set. |
343 """ | 344 """ |
344 self.updateStart(training_set) | 345 self.updateStart(training_set) |
345 stop=False | 346 stop=False |
347 if hasattr(self,'_minibatch_size') and self._minibatch_size: | |
348 minibatch_size=self._minibatch_size | |
349 else: | |
350 minibatch_size=min(100,len(training_set)) | |
346 while not stop: | 351 while not stop: |
347 if train_stats_collector: | 352 if train_stats_collector: |
348 train_stats_collector.forget() # restart stats collectin at the beginning of each epoch | 353 train_stats_collector.forget() # restart stats collectin at the beginning of each epoch |
349 for minibatch in training_set.minibatches(self.training_set_input_fields, | 354 for minibatch in training_set.minibatches(minibatch_size=minibatch_size): |
350 minibatch_size=self.minibatch_size): | 355 self.updateMinibatch(minibatch) |
351 self.update_minibatch(minibatch) | |
352 if train_stats_collector: | 356 if train_stats_collector: |
353 minibatch_set = minibatch.examples() | 357 minibatch_set = minibatch.examples() |
354 minibatch_set.setAttributes(self.attributeNames(),self.attributes()) | 358 minibatch_set.setAttributes(self.attributeNames(),self.attributes()) |
355 train_stats_collector.update(minibatch_set) | 359 train_stats_collector.update(minibatch_set) |
356 stop = self.isLastEpoch() | 360 stop = self.isLastEpoch() |
388 | 392 |
389 def updateMinibatchInputAttributes(self): | 393 def updateMinibatchInputAttributes(self): |
390 return self.parameterAttributes() | 394 return self.parameterAttributes() |
391 | 395 |
392 def updateMinibatchOutputAttributes(self): | 396 def updateMinibatchOutputAttributes(self): |
393 return ["_new"+name for name in self.parameterAttributes()] | 397 return ["new_"+name for name in self.parameterAttributes()] |
394 | 398 |
395 def updateEndInputAttributes(self): | 399 def updateEndInputAttributes(self): |
396 return self.parameterAttributes() | 400 return self.parameterAttributes() |
397 | 401 |
398 | 402 |