Mercurial > pylearn
comparison learner.py @ 128:ee5507af2c60
minor edits
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Wed, 07 May 2008 20:51:24 -0400 |
parents | 4efe6d36c061 |
children | 4c2280edcaf5 3d8e40e7ed18 |
comparison
equal
deleted
inserted
replaced
127:f959ad58facc | 128:ee5507af2c60 |
---|---|
45 and return the learned function. | 45 and return the learned function. |
46 """ | 46 """ |
47 self.forget() | 47 self.forget() |
48 return self.update(learning_task,train_stats_collector) | 48 return self.update(learning_task,train_stats_collector) |
49 | 49 |
50 def use(self,input_dataset,output_fields=None,copy_inputs=True): | 50 def use(self,input_dataset,output_fieldnames=None, |
51 """Once a Learner has been trained by one or more call to 'update', it can | 51 test_stats_collector=None,copy_inputs=True, |
52 be used with one or more calls to 'use'. The argument is a DataSet (possibly | 52 put_stats_in_output_dataset=True, |
53 containing a single example) and the result is a DataSet of the same length. | 53 output_attributes=[]): |
54 If output_fields is specified, it may be use to indicate which fields should | 54 """ |
55 Once a Learner has been trained by one or more call to 'update', it can | |
56 be used with one or more calls to 'use'. The argument is an input DataSet (possibly | |
57 containing a single example) and the result is an output DataSet of the same length. | |
58 If output_fieldnames is specified, it may be use to indicate which fields should | |
55 be constructed in the output DataSet (for example ['output','classification_error']). | 59 be constructed in the output DataSet (for example ['output','classification_error']). |
60 Otherwise, self.defaultOutputFields is called to choose the output fields. | |
56 Optionally, if copy_inputs, the input fields (of the input_dataset) can be made | 61 Optionally, if copy_inputs, the input fields (of the input_dataset) can be made |
57 visible in the output DataSet returned by this method. | 62 visible in the output DataSet returned by this method. |
58 """ | 63 Optionally, attributes of the learner can be copied in the output dataset, |
59 raise AbstractFunction() | 64 and statistics computed by the stats collector also put in the output dataset. |
60 | 65 Note the distinction between fields (which are example-wise quantities, e.g. 'input') |
66 and attributes (which are not, e.g. 'regularization_term'). | |
67 | |
68 We provide here a default implementation that does all this using | |
69 a sub-class defined method: minibatchwiseUseFunction. | |
70 | |
71 @todo check if some of the learner attributes are actually SPECIFIED | |
72 as attributes of the input_dataset, and if so use their values instead | |
73 of the ones in the learner. | |
74 | |
75 The learner tries to compute in the output dataset the output fields specified. | |
76 If None is specified then self.defaultOutputFields(input_dataset.fieldNames()) | |
77 is called to determine the output fields. | |
78 | |
79 Attributes of the learner can also optionally be copied into the output dataset. | |
80 If output_attributes is None then all of the attributes in self.AttributeNames() | |
81 are copied in the output dataset, but if it is [] (the default), then none are copied. | |
82 If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) | |
83 are also copied into the output dataset attributes. | |
84 """ | |
85 minibatchwise_use_function = self.minibatchwiseUseFunction(input_dataset.fieldNames(), | |
86 output_fieldnames, | |
87 test_stats_collector) | |
88 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, | |
89 minibatchwise_use_function, | |
90 True,DataSet.numpy_vstack, | |
91 DataSet.numpy_hstack) | |
92 # actually force the computation | |
93 output_dataset = CachedDataSet(virtual_output_dataset,True) | |
94 if copy_inputs: | |
95 output_dataset = input_dataset | output_dataset | |
96 # copy the wanted attributes in the dataset | |
97 if output_attributes is None: | |
98 output_attributes = self.attributeNames() | |
99 if output_attributes: | |
100 assert set(attribute_names) <= set(self.attributeNames()) | |
101 output_dataset.setAttributes(output_attributes, | |
102 self.names2attributes(output_attributes,return_copy=True)) | |
103 if test_stats_collector: | |
104 test_stats_collector.update(output_dataset) | |
105 if put_stats_in_output_dataset: | |
106 output_dataset.setAttributes(test_stats_collector.attributeNames(), | |
107 test_stats_collector.attributes()) | |
108 return output_dataset | |
109 | |
110 def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector): | |
111 """ | |
112 Returns a function that can map the given input fields to the given output fields | |
113 and to the attributes that the stats collector needs for its computation. | |
114 That function is expected to operate on minibatches. | |
115 The function returned makes use of the self.useInputAttributes() and | |
116 sets the attributes specified by self.useOutputAttributes(). | |
117 """ | |
61 def attributeNames(self): | 118 def attributeNames(self): |
62 """ | 119 """ |
63 A Learner may have attributes that it wishes to export to other objects. To automate | 120 A Learner may have attributes that it wishes to export to other objects. To automate |
64 such export, sub-classes should define here the names (list of strings) of these attributes. | 121 such export, sub-classes should define here the names (list of strings) of these attributes. |
65 | 122 |
66 @todo By default, attributeNames looks for all dictionary entries whose name does not start with _. | 123 @todo By default, attributeNames looks for all dictionary entries whose name does not start with _. |
67 """ | 124 """ |
68 return [] | 125 return [] |
126 | |
127 def attributes(self,return_copy=False): | |
128 """ | |
129 Return a list with the values of the learner's attributes (or optionally, a deep copy). | |
130 """ | |
131 return self.names2attributes(self.attributeNames(),return_copy) | |
132 | |
133 def names2attributes(self,names,return_copy=False): | |
134 """ | |
135 Private helper function that maps a list of attribute names to a list | |
136 of (optionally copies) values of attributes. | |
137 """ | |
138 if return_copy: | |
139 return [copy.deepcopy(self.__getattr__(name).data) for name in names] | |
140 else: | |
141 return [self.__getattr__(name).data for name in names] | |
69 | 142 |
70 def updateInputAttributes(self): | 143 def updateInputAttributes(self): |
71 """ | 144 """ |
72 A subset of self.attributeNames() which are the names of attributes needed by update() in order | 145 A subset of self.attributeNames() which are the names of attributes needed by update() in order |
73 to do its work. | 146 to do its work. |
143 This may involve looking at the input_fields (names) available in the | 216 This may involve looking at the input_fields (names) available in the |
144 input_dataset. | 217 input_dataset. |
145 """ | 218 """ |
146 raise AbstractFunction() | 219 raise AbstractFunction() |
147 | 220 |
148 def allocate(self, minibatch): | 221 def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector): |
149 """ | 222 """ |
150 This function is called at the beginning of each updateMinibatch | 223 Implement minibatchwiseUseFunction by exploiting Theano compilation |
151 and should be used to check that all required attributes have been | 224 and the expression graph defined by a sub-class constructor. |
152 allocated and initialized (usually this function calls forget() | |
153 when it has to do an initialization). | |
154 """ | |
155 raise AbstractFunction() | |
156 | |
157 def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector): | |
158 """ | |
159 Private helper function called by the generic TLearner.use. It returns a function | |
160 that can map the given input fields to the given output fields (along with the | |
161 attributes that the stats collector needs for its computation. The function | |
162 called also automatically makes use of the self.useInputAttributes() and | |
163 sets the self.useOutputAttributes(). | |
164 """ | 225 """ |
165 if not output_fields: | 226 if not output_fields: |
166 output_fields = self.defaultOutputFields(input_fields) | 227 output_fields = self.defaultOutputFields(input_fields) |
167 if stats_collector: | 228 if stats_collector: |
168 stats_collector_inputs = stats_collector.input2UpdateAttributes() | 229 stats_collector_inputs = stats_collector.input2UpdateAttributes() |
184 self.setAttributes(use_output_attributes,output_attribute_values) | 245 self.setAttributes(use_output_attributes,output_attribute_values) |
185 return output_field_values | 246 return output_field_values |
186 self.use_functions_dictionary[key]=f | 247 self.use_functions_dictionary[key]=f |
187 return self.use_functions_dictionary[key] | 248 return self.use_functions_dictionary[key] |
188 | 249 |
189 def attributes(self,return_copy=False): | |
190 """ | |
191 Return a list with the values of the learner's attributes (or optionally, a deep copy). | |
192 """ | |
193 return self.names2attributes(self.attributeNames(),return_copy) | |
194 | |
195 def names2attributes(self,names,return_copy=False): | |
196 """ | |
197 Private helper function that maps a list of attribute names to a list | |
198 of (optionally copies) values of attributes. | |
199 """ | |
200 if return_copy: | |
201 return [copy.deepcopy(self.__getattr__(name).data) for name in names] | |
202 else: | |
203 return [self.__getattr__(name).data for name in names] | |
204 | |
205 def names2OpResults(self,names): | 250 def names2OpResults(self,names): |
206 """ | 251 """ |
207 Private helper function that maps a list of attribute names to a list | 252 Private helper function that maps a list of attribute names to a list |
208 of corresponding Op Results (with the same name but with a '_' prefix). | 253 of corresponding Op Results (with the same name but with a '_' prefix). |
209 """ | 254 """ |
210 return [self.__getattr__('_'+name).data for name in names] | 255 return [self.__getattr__('_'+name).data for name in names] |
211 | |
212 def use(self,input_dataset,output_fieldnames=None,output_attributes=[], | |
213 test_stats_collector=None,copy_inputs=True, put_stats_in_output_dataset=True): | |
214 """ | |
215 The learner tries to compute in the output dataset the output fields specified | |
216 | |
217 @todo check if some of the learner attributes are actually SPECIFIED | |
218 as attributes of the input_dataset, and if so use their values instead | |
219 of the ones in the learner. | |
220 | |
221 The learner tries to compute in the output dataset the output fields specified. | |
222 If None is specified then self.defaultOutputFields(input_dataset.fieldNames()) | |
223 is called to determine the output fields. | |
224 | |
225 Attributes of the learner can also optionally be copied into the output dataset. | |
226 If output_attributes is None then all of the attributes in self.AttributeNames() | |
227 are copied in the output dataset, but if it is [] (the default), then none are copied. | |
228 If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) | |
229 are also copied into the output dataset attributes. | |
230 """ | |
231 minibatchwise_use_function = self.minibatchwise_use_functions(input_dataset.fieldNames(), | |
232 output_fieldnames, | |
233 test_stats_collector) | |
234 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, | |
235 minibatchwise_use_function, | |
236 True,DataSet.numpy_vstack, | |
237 DataSet.numpy_hstack) | |
238 # actually force the computation | |
239 output_dataset = CachedDataSet(virtual_output_dataset,True) | |
240 if copy_inputs: | |
241 output_dataset = input_dataset | output_dataset | |
242 # copy the wanted attributes in the dataset | |
243 if output_attributes is None: | |
244 output_attributes = self.attributeNames() | |
245 if output_attributes: | |
246 assert set(attribute_names) <= set(self.attributeNames()) | |
247 output_dataset.setAttributes(output_attributes, | |
248 self.names2attributes(output_attributes,return_copy=True)) | |
249 if test_stats_collector: | |
250 test_stats_collector.update(output_dataset) | |
251 if put_stats_in_output_dataset: | |
252 output_dataset.setAttributes(test_stats_collector.attributeNames(), | |
253 test_stats_collector.attributes()) | |
254 return output_dataset | |
255 | 256 |
256 | 257 |
257 class MinibatchUpdatesTLearner(TLearner): | 258 class MinibatchUpdatesTLearner(TLearner): |
258 """ | 259 """ |
259 This adds to TLearner a | 260 This adds to TLearner a |
279 self.names2OpResults(self.updateMinibatchOutputAttributes())) | 280 self.names2OpResults(self.updateMinibatchOutputAttributes())) |
280 self.update_end_function = compile.function | 281 self.update_end_function = compile.function |
281 (self.names2OpResults(self.updateEndInputAttributes()), | 282 (self.names2OpResults(self.updateEndInputAttributes()), |
282 self.names2OpResults(self.updateEndOutputAttributes())) | 283 self.names2OpResults(self.updateEndOutputAttributes())) |
283 | 284 |
285 def allocate(self, minibatch): | |
286 """ | |
287 This function is called at the beginning of each updateMinibatch | |
288 and should be used to check that all required attributes have been | |
289 allocated and initialized (usually this function calls forget() | |
290 when it has to do an initialization). | |
291 """ | |
292 raise AbstractFunction() | |
293 | |
284 def updateMinibatchInputFields(self): | 294 def updateMinibatchInputFields(self): |
285 raise AbstractFunction() | 295 raise AbstractFunction() |
286 | 296 |
287 def updateMinibatchInputAttributes(self): | 297 def updateMinibatchInputAttributes(self): |
288 raise AbstractFunction() | 298 raise AbstractFunction() |