comparison learner.py @ 134:3f4e5c9bdc5e

Fixes to ApplyFunctionDataSet and other things to make learner and mlp work
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Fri, 09 May 2008 17:38:57 -0400
parents b4657441dd65
children 0d8e721cc63c b7ca3545186b
comparison
equal deleted inserted replaced
133:b4657441dd65 134:3f4e5c9bdc5e
1 1
2 from dataset import AttributesHolder,AbstractFunction 2 from dataset import AttributesHolder,AbstractFunction,ApplyFunctionDataSet,DataSet,CachedDataSet
3 import compile 3 import theano
4 from theano import compile
4 from theano import tensor as t 5 from theano import tensor as t
5 6
6 class Learner(AttributesHolder): 7 class Learner(AttributesHolder):
7 """ 8 """
8 Base class for learning algorithms, provides an interface 9 Base class for learning algorithms, provides an interface
130 """ 131 """
131 Return a list with the values of the learner's attributes (or optionally, a deep copy). 132 Return a list with the values of the learner's attributes (or optionally, a deep copy).
132 """ 133 """
133 return self.names2attributes(self.attributeNames(),return_copy) 134 return self.names2attributes(self.attributeNames(),return_copy)
134 135
135 def names2attributes(self,names,return_copy=False): 136 def names2attributes(self,names):
136 """ 137 """
137 Private helper function that maps a list of attribute names to a list 138 Private helper function that maps a list of attribute names to a list
138 of (optionally copies) values of attributes. 139 of (optionally copies) values of attributes.
139 """ 140 """
140 if return_copy: 141 res=[]
141 return [copy.deepcopy(self.__getattribute__(name).data) for name in names] 142 for name in names:
142 else: 143 assert name in names
143 return [self.__getattribute__(name).data for name in names] 144 res.append(self.__getattribute__(name))
144 145 return res
145 def updateInputAttributes(self):
146 """
147 A subset of self.attributeNames() which are the names of attributes needed by update() in order
148 to do its work.
149 """
150 raise AbstractFunction()
151 146
152 def useInputAttributes(self): 147 def useInputAttributes(self):
153 """ 148 """
154 A subset of self.attributeNames() which are the names of attributes needed by use() in order 149 A subset of self.attributeNames() which are the names of attributes needed by use() in order
155 to do its work. 150 to do its work.
156 """ 151 """
157 raise AbstractFunction() 152 raise AbstractFunction()
158
159 def updateOutputAttributes(self):
160 """
161 A subset of self.attributeNames() which are the names of attributes modified/created by update() in order
162 to do its work.
163
164 By default these are inferred from the various update output attributes:
165 """
166 return ["parameters"] + self.updateMinibatchOutputAttributes() + self.updateEndOutputAttributes()
167 153
168 def useOutputAttributes(self): 154 def useOutputAttributes(self):
169 """ 155 """
170 A subset of self.attributeNames() which are the names of attributes modified/created by use() in order 156 A subset of self.attributeNames() which are the names of attributes modified/created by use() in order
171 to do its work. 157 to do its work.
208 dependant de Theano 194 dependant de Theano
209 """ 195 """
210 196
211 def __init__(self): 197 def __init__(self):
212 Learner.__init__(self) 198 Learner.__init__(self)
199 self.use_functions_dictionary={}
213 200
214 def defaultOutputFields(self, input_fields): 201 def defaultOutputFields(self, input_fields):
215 """ 202 """
216 Return a default list of output field names (to put in the output dataset). 203 Return a default list of output field names (to put in the output dataset).
217 This will be used when None are provided (as output_fields) by the caller of the 'use' method. 204 This will be used when None are provided (as output_fields) by the caller of the 'use' method.
230 if stats_collector: 217 if stats_collector:
231 stats_collector_inputs = stats_collector.input2UpdateAttributes() 218 stats_collector_inputs = stats_collector.input2UpdateAttributes()
232 for attribute in stats_collector_inputs: 219 for attribute in stats_collector_inputs:
233 if attribute not in input_fields: 220 if attribute not in input_fields:
234 output_fields.append(attribute) 221 output_fields.append(attribute)
235 key = (input_fields,output_fields) 222 key = (tuple(input_fields),tuple(output_fields))
236 if key not in self.use_functions_dictionary: 223 if key not in self.use_functions_dictionary:
237 use_input_attributes = self.useInputAttributes() 224 use_input_attributes = self.useInputAttributes()
238 use_output_attributes = self.useOutputAttributes() 225 use_output_attributes = self.useOutputAttributes()
239 complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes), 226 complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes),
240 self.names2OpResults(output_fields+use_output_attributes)) 227 self.names2OpResults(output_fields+use_output_attributes))
241 def f(*input_field_values): 228 def f(*input_field_values):
242 input_attribute_values = self.names2attributes(use_input_attributes) 229 input_attribute_values = self.names2attributes(use_input_attributes)
243 results = complete_f(*(input_field_values + input_attribute_values)) 230 results = complete_f(*(list(input_field_values) + input_attribute_values))
244 output_field_values = results[0:len(output_fields)] 231 output_field_values = results[0:len(output_fields)]
245 output_attribute_values = results[len(output_fields):len(results)] 232 output_attribute_values = results[len(output_fields):len(results)]
246 if use_output_attributes: 233 if use_output_attributes:
247 self.setAttributes(use_output_attributes,output_attribute_values) 234 self.setAttributes(use_output_attributes,output_attribute_values)
248 return output_field_values 235 return output_field_values
274 261
275 """ 262 """
276 263
277 def __init__(self): 264 def __init__(self):
278 TLearner.__init__(self) 265 TLearner.__init__(self)
279 self.update_minibatch_function = compile.function 266 self.update_minibatch_function = compile.function(self.names2OpResults(self.updateMinibatchOutputAttributes()+
280 (self.names2OpResults(self.updateMinibatchOutputAttributes()+ 267 self.updateMinibatchInputFields()),
281 self.updateMinibatchInputFields()), 268 self.names2OpResults(self.updateMinibatchOutputAttributes()))
282 self.names2OpResults(self.updateMinibatchOutputAttributes())) 269 self.update_end_function = compile.function(self.names2OpResults(self.updateEndInputAttributes()),
283 self.update_end_function = compile.function 270 self.names2OpResults(self.updateEndOutputAttributes()))
284 (self.names2OpResults(self.updateEndInputAttributes()),
285 self.names2OpResults(self.updateEndOutputAttributes()))
286 271
287 def allocate(self, minibatch): 272 def allocate(self, minibatch):
288 """ 273 """
289 This function is called at the beginning of each L{updateMinibatch} 274 This function is called at the beginning of each L{updateMinibatch}
290 and should be used to check that all required attributes have been 275 and should be used to check that all required attributes have been
314 def updateStart(self,training_set): 299 def updateStart(self,training_set):
315 pass 300 pass
316 301
317 def updateEnd(self): 302 def updateEnd(self):
318 self.setAttributes(self.updateEndOutputAttributes(), 303 self.setAttributes(self.updateEndOutputAttributes(),
319 self.update_end_function 304 self.update_end_function(*self.names2attributes(self.updateEndInputAttributes())))
320 (self.names2attributes(self.updateEndInputAttributes())))
321 self.parameters = self.names2attributes(self.parameterAttributes()) 305 self.parameters = self.names2attributes(self.parameterAttributes())
322 306
323 def updateMinibatch(self,minibatch): 307 def updateMinibatch(self,minibatch):
324 # make sure all required fields are allocated and initialized 308 # make sure all required fields are allocated and initialized
325 self.allocate(minibatch) 309 self.allocate(minibatch)
310 input_attributes = self.names2attributes(self.updateMinibatchInputAttributes())
311 input_fields = minibatch(*self.updateMinibatchInputFields())
326 self.setAttributes(self.updateMinibatchOutputAttributes(), 312 self.setAttributes(self.updateMinibatchOutputAttributes(),
327 # concatenate the attribute values and field values and then apply update fn 313 # concatenate the attribute values and field values and then apply update fn
328 self.update_minibatch_function(*(self.names2attributes 314 self.update_minibatch_function(*(input_attributes+input_fields)))
329 (self.updateMinibatchInputAttributes()))
330 + minibatch(self.updateMinibatchInputFields())))
331 315
332 def isLastEpoch(self): 316 def isLastEpoch(self):
333 """ 317 """
334 This method is called at the end of each epoch (cycling over the training set). 318 This method is called at the end of each epoch (cycling over the training set).
335 It returns a boolean to indicate if this is the last epoch. 319 It returns a boolean to indicate if this is the last epoch.
385 self.setAttributes(new_params_names, 369 self.setAttributes(new_params_names,
386 [t.add_inplace(param,self._learning_rate*t.grad(loss,param)) 370 [t.add_inplace(param,self._learning_rate*t.grad(loss,param))
387 for param in old_params]) 371 for param in old_params])
388 MinibatchUpdatesTLearner.__init__(self) 372 MinibatchUpdatesTLearner.__init__(self)
389 373
374
375 def namesOfAttributesToComputeOutputs(self,output_names):
376 """
377 The output_names are attribute names (not the corresponding Result names, which have leading _).
378 Return the corresponding input names
379 """
380 all_inputs = t.gof.graph.inputs(self.names2OpResults(output_names))
381 # remove constants and leading '_' in name
382
383 return [r.name for r in all_inputs if isinstance(r,theano.Result) and \
384 not isinstance(r,theano.Constant) and not isinstance(r,theano.Value)]
385 #inputs = []
386 #for r in all_inputs:
387 # if isinstance(r,theano.Result) and \
388 # not isinstance(r,theano.Constant) and not isinstance(r,theano.Value):
389 # inputs.append(r.name)
390 #return inputs
391
390 def isLastEpoch(self): 392 def isLastEpoch(self):
391 return self.truly_online 393 return self.truly_online
392 394
393 def updateMinibatchInputAttributes(self): 395 def updateMinibatchInputAttributes(self):
394 return self.parameterAttributes() 396 return self.parameterAttributes()
395 397
396 def updateMinibatchOutputAttributes(self): 398 def updateMinibatchOutputAttributes(self):
397 return ["new_"+name for name in self.parameterAttributes()] 399 return ["new_"+name for name in self.parameterAttributes()]
398 400
399 def updateEndInputAttributes(self): 401 def updateEndInputAttributes(self):
402 return self.namesOfAttributesToComputeOutputs(self.updateEndOutputAttributes())
403
404 def useInputAttributes(self):
400 return self.parameterAttributes() 405 return self.parameterAttributes()
401 406
402 407 def useOutputAttributes(self):
408 return []
409