Mercurial > pylearn
comparison learner.py @ 134:3f4e5c9bdc5e
Fixes to ApplyFunctionDataSet and other things to make learner and mlp work
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Fri, 09 May 2008 17:38:57 -0400 |
parents | b4657441dd65 |
children | 0d8e721cc63c b7ca3545186b |
comparison
equal
deleted
inserted
replaced
133:b4657441dd65 | 134:3f4e5c9bdc5e |
---|---|
1 | 1 |
2 from dataset import AttributesHolder,AbstractFunction | 2 from dataset import AttributesHolder,AbstractFunction,ApplyFunctionDataSet,DataSet,CachedDataSet |
3 import compile | 3 import theano |
4 from theano import compile | |
4 from theano import tensor as t | 5 from theano import tensor as t |
5 | 6 |
6 class Learner(AttributesHolder): | 7 class Learner(AttributesHolder): |
7 """ | 8 """ |
8 Base class for learning algorithms, provides an interface | 9 Base class for learning algorithms, provides an interface |
130 """ | 131 """ |
131 Return a list with the values of the learner's attributes (or optionally, a deep copy). | 132 Return a list with the values of the learner's attributes (or optionally, a deep copy). |
132 """ | 133 """ |
133 return self.names2attributes(self.attributeNames(),return_copy) | 134 return self.names2attributes(self.attributeNames(),return_copy) |
134 | 135 |
135 def names2attributes(self,names,return_copy=False): | 136 def names2attributes(self,names): |
136 """ | 137 """ |
137 Private helper function that maps a list of attribute names to a list | 138 Private helper function that maps a list of attribute names to a list |
138 of (optionally copies) values of attributes. | 139 of (optionally copies) values of attributes. |
139 """ | 140 """ |
140 if return_copy: | 141 res=[] |
141 return [copy.deepcopy(self.__getattribute__(name).data) for name in names] | 142 for name in names: |
142 else: | 143 assert name in names |
143 return [self.__getattribute__(name).data for name in names] | 144 res.append(self.__getattribute__(name)) |
144 | 145 return res |
145 def updateInputAttributes(self): | |
146 """ | |
147 A subset of self.attributeNames() which are the names of attributes needed by update() in order | |
148 to do its work. | |
149 """ | |
150 raise AbstractFunction() | |
151 | 146 |
152 def useInputAttributes(self): | 147 def useInputAttributes(self): |
153 """ | 148 """ |
154 A subset of self.attributeNames() which are the names of attributes needed by use() in order | 149 A subset of self.attributeNames() which are the names of attributes needed by use() in order |
155 to do its work. | 150 to do its work. |
156 """ | 151 """ |
157 raise AbstractFunction() | 152 raise AbstractFunction() |
158 | |
159 def updateOutputAttributes(self): | |
160 """ | |
161 A subset of self.attributeNames() which are the names of attributes modified/created by update() in order | |
162 to do its work. | |
163 | |
164 By default these are inferred from the various update output attributes: | |
165 """ | |
166 return ["parameters"] + self.updateMinibatchOutputAttributes() + self.updateEndOutputAttributes() | |
167 | 153 |
168 def useOutputAttributes(self): | 154 def useOutputAttributes(self): |
169 """ | 155 """ |
170 A subset of self.attributeNames() which are the names of attributes modified/created by use() in order | 156 A subset of self.attributeNames() which are the names of attributes modified/created by use() in order |
171 to do its work. | 157 to do its work. |
208 dependant de Theano | 194 dependant de Theano |
209 """ | 195 """ |
210 | 196 |
211 def __init__(self): | 197 def __init__(self): |
212 Learner.__init__(self) | 198 Learner.__init__(self) |
199 self.use_functions_dictionary={} | |
213 | 200 |
214 def defaultOutputFields(self, input_fields): | 201 def defaultOutputFields(self, input_fields): |
215 """ | 202 """ |
216 Return a default list of output field names (to put in the output dataset). | 203 Return a default list of output field names (to put in the output dataset). |
217 This will be used when None are provided (as output_fields) by the caller of the 'use' method. | 204 This will be used when None are provided (as output_fields) by the caller of the 'use' method. |
230 if stats_collector: | 217 if stats_collector: |
231 stats_collector_inputs = stats_collector.input2UpdateAttributes() | 218 stats_collector_inputs = stats_collector.input2UpdateAttributes() |
232 for attribute in stats_collector_inputs: | 219 for attribute in stats_collector_inputs: |
233 if attribute not in input_fields: | 220 if attribute not in input_fields: |
234 output_fields.append(attribute) | 221 output_fields.append(attribute) |
235 key = (input_fields,output_fields) | 222 key = (tuple(input_fields),tuple(output_fields)) |
236 if key not in self.use_functions_dictionary: | 223 if key not in self.use_functions_dictionary: |
237 use_input_attributes = self.useInputAttributes() | 224 use_input_attributes = self.useInputAttributes() |
238 use_output_attributes = self.useOutputAttributes() | 225 use_output_attributes = self.useOutputAttributes() |
239 complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes), | 226 complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes), |
240 self.names2OpResults(output_fields+use_output_attributes)) | 227 self.names2OpResults(output_fields+use_output_attributes)) |
241 def f(*input_field_values): | 228 def f(*input_field_values): |
242 input_attribute_values = self.names2attributes(use_input_attributes) | 229 input_attribute_values = self.names2attributes(use_input_attributes) |
243 results = complete_f(*(input_field_values + input_attribute_values)) | 230 results = complete_f(*(list(input_field_values) + input_attribute_values)) |
244 output_field_values = results[0:len(output_fields)] | 231 output_field_values = results[0:len(output_fields)] |
245 output_attribute_values = results[len(output_fields):len(results)] | 232 output_attribute_values = results[len(output_fields):len(results)] |
246 if use_output_attributes: | 233 if use_output_attributes: |
247 self.setAttributes(use_output_attributes,output_attribute_values) | 234 self.setAttributes(use_output_attributes,output_attribute_values) |
248 return output_field_values | 235 return output_field_values |
274 | 261 |
275 """ | 262 """ |
276 | 263 |
277 def __init__(self): | 264 def __init__(self): |
278 TLearner.__init__(self) | 265 TLearner.__init__(self) |
279 self.update_minibatch_function = compile.function | 266 self.update_minibatch_function = compile.function(self.names2OpResults(self.updateMinibatchOutputAttributes()+ |
280 (self.names2OpResults(self.updateMinibatchOutputAttributes()+ | 267 self.updateMinibatchInputFields()), |
281 self.updateMinibatchInputFields()), | 268 self.names2OpResults(self.updateMinibatchOutputAttributes())) |
282 self.names2OpResults(self.updateMinibatchOutputAttributes())) | 269 self.update_end_function = compile.function(self.names2OpResults(self.updateEndInputAttributes()), |
283 self.update_end_function = compile.function | 270 self.names2OpResults(self.updateEndOutputAttributes())) |
284 (self.names2OpResults(self.updateEndInputAttributes()), | |
285 self.names2OpResults(self.updateEndOutputAttributes())) | |
286 | 271 |
287 def allocate(self, minibatch): | 272 def allocate(self, minibatch): |
288 """ | 273 """ |
289 This function is called at the beginning of each L{updateMinibatch} | 274 This function is called at the beginning of each L{updateMinibatch} |
290 and should be used to check that all required attributes have been | 275 and should be used to check that all required attributes have been |
314 def updateStart(self,training_set): | 299 def updateStart(self,training_set): |
315 pass | 300 pass |
316 | 301 |
317 def updateEnd(self): | 302 def updateEnd(self): |
318 self.setAttributes(self.updateEndOutputAttributes(), | 303 self.setAttributes(self.updateEndOutputAttributes(), |
319 self.update_end_function | 304 self.update_end_function(*self.names2attributes(self.updateEndInputAttributes()))) |
320 (self.names2attributes(self.updateEndInputAttributes()))) | |
321 self.parameters = self.names2attributes(self.parameterAttributes()) | 305 self.parameters = self.names2attributes(self.parameterAttributes()) |
322 | 306 |
323 def updateMinibatch(self,minibatch): | 307 def updateMinibatch(self,minibatch): |
324 # make sure all required fields are allocated and initialized | 308 # make sure all required fields are allocated and initialized |
325 self.allocate(minibatch) | 309 self.allocate(minibatch) |
310 input_attributes = self.names2attributes(self.updateMinibatchInputAttributes()) | |
311 input_fields = minibatch(*self.updateMinibatchInputFields()) | |
326 self.setAttributes(self.updateMinibatchOutputAttributes(), | 312 self.setAttributes(self.updateMinibatchOutputAttributes(), |
327 # concatenate the attribute values and field values and then apply update fn | 313 # concatenate the attribute values and field values and then apply update fn |
328 self.update_minibatch_function(*(self.names2attributes | 314 self.update_minibatch_function(*(input_attributes+input_fields))) |
329 (self.updateMinibatchInputAttributes())) | |
330 + minibatch(self.updateMinibatchInputFields()))) | |
331 | 315 |
332 def isLastEpoch(self): | 316 def isLastEpoch(self): |
333 """ | 317 """ |
334 This method is called at the end of each epoch (cycling over the training set). | 318 This method is called at the end of each epoch (cycling over the training set). |
335 It returns a boolean to indicate if this is the last epoch. | 319 It returns a boolean to indicate if this is the last epoch. |
385 self.setAttributes(new_params_names, | 369 self.setAttributes(new_params_names, |
386 [t.add_inplace(param,self._learning_rate*t.grad(loss,param)) | 370 [t.add_inplace(param,self._learning_rate*t.grad(loss,param)) |
387 for param in old_params]) | 371 for param in old_params]) |
388 MinibatchUpdatesTLearner.__init__(self) | 372 MinibatchUpdatesTLearner.__init__(self) |
389 | 373 |
374 | |
375 def namesOfAttributesToComputeOutputs(self,output_names): | |
376 """ | |
377 The output_names are attribute names (not the corresponding Result names, which have leading _). | |
378 Return the corresponding input names | |
379 """ | |
380 all_inputs = t.gof.graph.inputs(self.names2OpResults(output_names)) | |
381 # remove constants and leading '_' in name | |
382 | |
383 return [r.name for r in all_inputs if isinstance(r,theano.Result) and \ | |
384 not isinstance(r,theano.Constant) and not isinstance(r,theano.Value)] | |
385 #inputs = [] | |
386 #for r in all_inputs: | |
387 # if isinstance(r,theano.Result) and \ | |
388 # not isinstance(r,theano.Constant) and not isinstance(r,theano.Value): | |
389 # inputs.append(r.name) | |
390 #return inputs | |
391 | |
390 def isLastEpoch(self): | 392 def isLastEpoch(self): |
391 return self.truly_online | 393 return self.truly_online |
392 | 394 |
393 def updateMinibatchInputAttributes(self): | 395 def updateMinibatchInputAttributes(self): |
394 return self.parameterAttributes() | 396 return self.parameterAttributes() |
395 | 397 |
396 def updateMinibatchOutputAttributes(self): | 398 def updateMinibatchOutputAttributes(self): |
397 return ["new_"+name for name in self.parameterAttributes()] | 399 return ["new_"+name for name in self.parameterAttributes()] |
398 | 400 |
399 def updateEndInputAttributes(self): | 401 def updateEndInputAttributes(self): |
402 return self.namesOfAttributesToComputeOutputs(self.updateEndOutputAttributes()) | |
403 | |
404 def useInputAttributes(self): | |
400 return self.parameterAttributes() | 405 return self.parameterAttributes() |
401 | 406 |
402 | 407 def useOutputAttributes(self): |
408 return [] | |
409 |