comparison learner.py @ 133:b4657441dd65

Corrected typos
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Fri, 09 May 2008 13:38:54 -0400
parents f6505ec32dc3
children 3f4e5c9bdc5e
comparison
equal deleted inserted replaced
132:f6505ec32dc3 133:b4657441dd65
45 """ 45 """
46 Train a learner from scratch using the provided training set, 46 Train a learner from scratch using the provided training set,
47 and return the learned function. 47 and return the learned function.
48 """ 48 """
49 self.forget() 49 self.forget()
50 return self.update(learning_task,train_stats_collector) 50 return self.update(training_set,train_stats_collector)
51 51
52 def use(self,input_dataset,output_fieldnames=None, 52 def use(self,input_dataset,output_fieldnames=None,
53 test_stats_collector=None,copy_inputs=True, 53 test_stats_collector=None,copy_inputs=True,
54 put_stats_in_output_dataset=True, 54 put_stats_in_output_dataset=True,
55 output_attributes=[]): 55 output_attributes=[]):
252 def names2OpResults(self,names): 252 def names2OpResults(self,names):
253 """ 253 """
254 Private helper function that maps a list of attribute names to a list 254 Private helper function that maps a list of attribute names to a list
255 of corresponding Op Results (with the same name but with a '_' prefix). 255 of corresponding Op Results (with the same name but with a '_' prefix).
256 """ 256 """
257 return [self.__getattribute__('_'+name).data for name in names] 257 return [self.__getattribute__('_'+name) for name in names]
258 258
259 259
260 class MinibatchUpdatesTLearner(TLearner): 260 class MinibatchUpdatesTLearner(TLearner):
261 """ 261 """
262 This adds to L{TLearner} a 262 This adds to L{TLearner} a
309 raise AbstractFunction() 309 raise AbstractFunction()
310 310
311 def parameterAttributes(self): 311 def parameterAttributes(self):
312 raise AbstractFunction() 312 raise AbstractFunction()
313 313
314 def updateStart(self): pass 314 def updateStart(self,training_set):
315 pass
315 316
316 def updateEnd(self): 317 def updateEnd(self):
317 self.setAttributes(self.updateEndOutputAttributes(), 318 self.setAttributes(self.updateEndOutputAttributes(),
318 self.update_end_function 319 self.update_end_function
319 (self.names2attributes(self.updateEndInputAttributes()))) 320 (self.names2attributes(self.updateEndInputAttributes())))
341 @todo check if some of the learner attributes are actually SPECIFIED 342 @todo check if some of the learner attributes are actually SPECIFIED
342 in as attributes of the training_set. 343 in as attributes of the training_set.
343 """ 344 """
344 self.updateStart(training_set) 345 self.updateStart(training_set)
345 stop=False 346 stop=False
347 if hasattr(self,'_minibatch_size') and self._minibatch_size:
348 minibatch_size=self._minibatch_size
349 else:
350 minibatch_size=min(100,len(training_set))
346 while not stop: 351 while not stop:
347 if train_stats_collector: 352 if train_stats_collector:
348 train_stats_collector.forget() # restart stats collectin at the beginning of each epoch 353 train_stats_collector.forget() # restart stats collectin at the beginning of each epoch
349 for minibatch in training_set.minibatches(self.training_set_input_fields, 354 for minibatch in training_set.minibatches(minibatch_size=minibatch_size):
350 minibatch_size=self.minibatch_size): 355 self.updateMinibatch(minibatch)
351 self.update_minibatch(minibatch)
352 if train_stats_collector: 356 if train_stats_collector:
353 minibatch_set = minibatch.examples() 357 minibatch_set = minibatch.examples()
354 minibatch_set.setAttributes(self.attributeNames(),self.attributes()) 358 minibatch_set.setAttributes(self.attributeNames(),self.attributes())
355 train_stats_collector.update(minibatch_set) 359 train_stats_collector.update(minibatch_set)
356 stop = self.isLastEpoch() 360 stop = self.isLastEpoch()
388 392
389 def updateMinibatchInputAttributes(self): 393 def updateMinibatchInputAttributes(self):
390 return self.parameterAttributes() 394 return self.parameterAttributes()
391 395
392 def updateMinibatchOutputAttributes(self): 396 def updateMinibatchOutputAttributes(self):
393 return ["_new"+name for name in self.parameterAttributes()] 397 return ["new_"+name for name in self.parameterAttributes()]
394 398
395 def updateEndInputAttributes(self): 399 def updateEndInputAttributes(self):
396 return self.parameterAttributes() 400 return self.parameterAttributes()
397 401
398 402