comparison learner.py @ 128:ee5507af2c60

minor edits
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Wed, 07 May 2008 20:51:24 -0400
parents 4efe6d36c061
children 4c2280edcaf5 3d8e40e7ed18
comparison
equal deleted inserted replaced
127:f959ad58facc 128:ee5507af2c60
45 and return the learned function. 45 and return the learned function.
46 """ 46 """
47 self.forget() 47 self.forget()
48 return self.update(learning_task,train_stats_collector) 48 return self.update(learning_task,train_stats_collector)
49 49
50 def use(self,input_dataset,output_fields=None,copy_inputs=True): 50 def use(self,input_dataset,output_fieldnames=None,
51 """Once a Learner has been trained by one or more call to 'update', it can 51 test_stats_collector=None,copy_inputs=True,
52 be used with one or more calls to 'use'. The argument is a DataSet (possibly 52 put_stats_in_output_dataset=True,
53 containing a single example) and the result is a DataSet of the same length. 53 output_attributes=[]):
54 If output_fields is specified, it may be use to indicate which fields should 54 """
55 Once a Learner has been trained by one or more calls to 'update', it can
56 be used with one or more calls to 'use'. The argument is an input DataSet (possibly
57 containing a single example) and the result is an output DataSet of the same length.
58 If output_fieldnames is specified, it may be used to indicate which fields should
55 be constructed in the output DataSet (for example ['output','classification_error']). 59 be constructed in the output DataSet (for example ['output','classification_error']).
60 Otherwise, self.defaultOutputFields is called to choose the output fields.
56 Optionally, if copy_inputs, the input fields (of the input_dataset) can be made 61 Optionally, if copy_inputs, the input fields (of the input_dataset) can be made
57 visible in the output DataSet returned by this method. 62 visible in the output DataSet returned by this method.
58 """ 63 Optionally, attributes of the learner can be copied in the output dataset,
59 raise AbstractFunction() 64 and statistics computed by the stats collector also put in the output dataset.
60 65 Note the distinction between fields (which are example-wise quantities, e.g. 'input')
66 and attributes (which are not, e.g. 'regularization_term').
67
68 We provide here a default implementation that does all this using
69 a sub-class defined method: minibatchwiseUseFunction.
70
71 @todo check if some of the learner attributes are actually SPECIFIED
72 as attributes of the input_dataset, and if so use their values instead
73 of the ones in the learner.
74
75 The learner tries to compute in the output dataset the output fields specified.
76 If None is specified then self.defaultOutputFields(input_dataset.fieldNames())
77 is called to determine the output fields.
78
79 Attributes of the learner can also optionally be copied into the output dataset.
80 If output_attributes is None then all of the attributes in self.attributeNames()
81 are copied in the output dataset, but if it is [] (the default), then none are copied.
82 If a test_stats_collector is provided, then its attributes (test_stats_collector.attributeNames())
83 are also copied into the output dataset attributes.
84 """
85 minibatchwise_use_function = self.minibatchwiseUseFunction(input_dataset.fieldNames(),
86 output_fieldnames,
87 test_stats_collector)
88 virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
89 minibatchwise_use_function,
90 True,DataSet.numpy_vstack,
91 DataSet.numpy_hstack)
92 # actually force the computation
93 output_dataset = CachedDataSet(virtual_output_dataset,True)
94 if copy_inputs:
95 output_dataset = input_dataset | output_dataset
96 # copy the wanted attributes in the dataset
97 if output_attributes is None:
98 output_attributes = self.attributeNames()
99 if output_attributes:
100 assert set(attribute_names) <= set(self.attributeNames())
101 output_dataset.setAttributes(output_attributes,
102 self.names2attributes(output_attributes,return_copy=True))
103 if test_stats_collector:
104 test_stats_collector.update(output_dataset)
105 if put_stats_in_output_dataset:
106 output_dataset.setAttributes(test_stats_collector.attributeNames(),
107 test_stats_collector.attributes())
108 return output_dataset
109
110 def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector):
111 """
112 Returns a function that can map the given input fields to the given output fields
113 and to the attributes that the stats collector needs for its computation.
114 That function is expected to operate on minibatches.
115 The function returned makes use of the self.useInputAttributes() and
116 sets the attributes specified by self.useOutputAttributes().
117 """
61 def attributeNames(self): 118 def attributeNames(self):
62 """ 119 """
63 A Learner may have attributes that it wishes to export to other objects. To automate 120 A Learner may have attributes that it wishes to export to other objects. To automate
64 such export, sub-classes should define here the names (list of strings) of these attributes. 121 such export, sub-classes should define here the names (list of strings) of these attributes.
65 122
66 @todo By default, attributeNames looks for all dictionary entries whose name does not start with _. 123 @todo By default, attributeNames looks for all dictionary entries whose name does not start with _.
67 """ 124 """
68 return [] 125 return []
126
127 def attributes(self,return_copy=False):
128 """
129 Return a list with the values of the learner's attributes (or optionally, a deep copy).
130 """
131 return self.names2attributes(self.attributeNames(),return_copy)
132
133 def names2attributes(self,names,return_copy=False):
134 """
135 Private helper function that maps a list of attribute names to a list
136 of (optionally copied) values of attributes.
137 """
138 if return_copy:
139 return [copy.deepcopy(self.__getattr__(name).data) for name in names]
140 else:
141 return [self.__getattr__(name).data for name in names]
69 142
70 def updateInputAttributes(self): 143 def updateInputAttributes(self):
71 """ 144 """
72 A subset of self.attributeNames() which are the names of attributes needed by update() in order 145 A subset of self.attributeNames() which are the names of attributes needed by update() in order
73 to do its work. 146 to do its work.
143 This may involve looking at the input_fields (names) available in the 216 This may involve looking at the input_fields (names) available in the
144 input_dataset. 217 input_dataset.
145 """ 218 """
146 raise AbstractFunction() 219 raise AbstractFunction()
147 220
148 def allocate(self, minibatch): 221 def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector):
149 """ 222 """
150 This function is called at the beginning of each updateMinibatch 223 Implement minibatchwiseUseFunction by exploiting Theano compilation
151 and should be used to check that all required attributes have been 224 and the expression graph defined by a sub-class constructor.
152 allocated and initialized (usually this function calls forget()
153 when it has to do an initialization).
154 """
155 raise AbstractFunction()
156
157 def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
158 """
159 Private helper function called by the generic TLearner.use. It returns a function
160 that can map the given input fields to the given output fields (along with the
161 attributes that the stats collector needs for its computation). The function
162 called also automatically makes use of the self.useInputAttributes() and
163 sets the self.useOutputAttributes().
164 """ 225 """
165 if not output_fields: 226 if not output_fields:
166 output_fields = self.defaultOutputFields(input_fields) 227 output_fields = self.defaultOutputFields(input_fields)
167 if stats_collector: 228 if stats_collector:
168 stats_collector_inputs = stats_collector.input2UpdateAttributes() 229 stats_collector_inputs = stats_collector.input2UpdateAttributes()
184 self.setAttributes(use_output_attributes,output_attribute_values) 245 self.setAttributes(use_output_attributes,output_attribute_values)
185 return output_field_values 246 return output_field_values
186 self.use_functions_dictionary[key]=f 247 self.use_functions_dictionary[key]=f
187 return self.use_functions_dictionary[key] 248 return self.use_functions_dictionary[key]
188 249
189 def attributes(self,return_copy=False):
190 """
191 Return a list with the values of the learner's attributes (or optionally, a deep copy).
192 """
193 return self.names2attributes(self.attributeNames(),return_copy)
194
195 def names2attributes(self,names,return_copy=False):
196 """
197 Private helper function that maps a list of attribute names to a list
198 of (optionally copied) values of attributes.
199 """
200 if return_copy:
201 return [copy.deepcopy(self.__getattr__(name).data) for name in names]
202 else:
203 return [self.__getattr__(name).data for name in names]
204
205 def names2OpResults(self,names): 250 def names2OpResults(self,names):
206 """ 251 """
207 Private helper function that maps a list of attribute names to a list 252 Private helper function that maps a list of attribute names to a list
208 of corresponding Op Results (with the same name but with a '_' prefix). 253 of corresponding Op Results (with the same name but with a '_' prefix).
209 """ 254 """
210 return [self.__getattr__('_'+name).data for name in names] 255 return [self.__getattr__('_'+name).data for name in names]
211
212 def use(self,input_dataset,output_fieldnames=None,output_attributes=[],
213 test_stats_collector=None,copy_inputs=True, put_stats_in_output_dataset=True):
214 """
215 The learner tries to compute in the output dataset the output fields specified
216
217 @todo check if some of the learner attributes are actually SPECIFIED
218 as attributes of the input_dataset, and if so use their values instead
219 of the ones in the learner.
220
221 The learner tries to compute in the output dataset the output fields specified.
222 If None is specified then self.defaultOutputFields(input_dataset.fieldNames())
223 is called to determine the output fields.
224
225 Attributes of the learner can also optionally be copied into the output dataset.
226 If output_attributes is None then all of the attributes in self.attributeNames()
227 are copied in the output dataset, but if it is [] (the default), then none are copied.
228 If a test_stats_collector is provided, then its attributes (test_stats_collector.attributeNames())
229 are also copied into the output dataset attributes.
230 """
231 minibatchwise_use_function = self.minibatchwise_use_functions(input_dataset.fieldNames(),
232 output_fieldnames,
233 test_stats_collector)
234 virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
235 minibatchwise_use_function,
236 True,DataSet.numpy_vstack,
237 DataSet.numpy_hstack)
238 # actually force the computation
239 output_dataset = CachedDataSet(virtual_output_dataset,True)
240 if copy_inputs:
241 output_dataset = input_dataset | output_dataset
242 # copy the wanted attributes in the dataset
243 if output_attributes is None:
244 output_attributes = self.attributeNames()
245 if output_attributes:
246 assert set(attribute_names) <= set(self.attributeNames())
247 output_dataset.setAttributes(output_attributes,
248 self.names2attributes(output_attributes,return_copy=True))
249 if test_stats_collector:
250 test_stats_collector.update(output_dataset)
251 if put_stats_in_output_dataset:
252 output_dataset.setAttributes(test_stats_collector.attributeNames(),
253 test_stats_collector.attributes())
254 return output_dataset
255 256
256 257
257 class MinibatchUpdatesTLearner(TLearner): 258 class MinibatchUpdatesTLearner(TLearner):
258 """ 259 """
259 This adds to TLearner a 260 This adds to TLearner a
279 self.names2OpResults(self.updateMinibatchOutputAttributes())) 280 self.names2OpResults(self.updateMinibatchOutputAttributes()))
280 self.update_end_function = compile.function 281 self.update_end_function = compile.function
281 (self.names2OpResults(self.updateEndInputAttributes()), 282 (self.names2OpResults(self.updateEndInputAttributes()),
282 self.names2OpResults(self.updateEndOutputAttributes())) 283 self.names2OpResults(self.updateEndOutputAttributes()))
283 284
285 def allocate(self, minibatch):
286 """
287 This function is called at the beginning of each updateMinibatch
288 and should be used to check that all required attributes have been
289 allocated and initialized (usually this function calls forget()
290 when it has to do an initialization).
291 """
292 raise AbstractFunction()
293
284 def updateMinibatchInputFields(self): 294 def updateMinibatchInputFields(self):
285 raise AbstractFunction() 295 raise AbstractFunction()
286 296
287 def updateMinibatchInputAttributes(self): 297 def updateMinibatchInputAttributes(self):
288 raise AbstractFunction() 298 raise AbstractFunction()