Mercurial > pylearn
comparison learner.py @ 110:8fa1ef2411a0
Worked on OneShotTLearner and implementation of LinearRegression
author | bengioy@bengiomac.local |
---|---|
date | Tue, 06 May 2008 22:24:55 -0400 |
parents | d97f6fe6bdf9 |
children | 88257dfedf8c |
109:d97f6fe6bdf9 | 110:8fa1ef2411a0 |
---|---|
1 | 1 |
2 from dataset import * | 2 from dataset import * |
3 | 3 |
4 class Learner(object): | 4 class Learner(AttributesHolder): |
5 """Base class for learning algorithms, provides an interface | 5 """Base class for learning algorithms, provides an interface |
6 that allows various algorithms to be applicable to generic learning | 6 that allows various algorithms to be applicable to generic learning |
7 algorithms. | 7 algorithms. |
8 | 8 |
9 A Learner can be seen as a learning algorithm, a function that when | 9 A Learner can be seen as a learning algorithm, a function that when |
64 | 64 |
65 @todo By default, attributeNames looks for all dictionary entries whose name does not start with _. | 65 @todo By default, attributeNames looks for all dictionary entries whose name does not start with _. |
66 """ | 66 """ |
67 return [] | 67 return [] |
68 | 68 |
| 69 def updateInputAttributes(self): |
| 70 """ |
| 71 A subset of self.attributeNames() which are the names of attributes needed by update() in order |
| 72 to do its work. |
| 73 """ |
| 74 raise AbstractFunction() |
| 75 |
| 76 def useInputAttributes(self): |
| 77 """ |
| 78 A subset of self.attributeNames() which are the names of attributes needed by use() in order |
| 79 to do its work. |
| 80 """ |
| 81 raise AbstractFunction() |
| 82 |
| 83 def updateOutputAttributes(self): |
| 84 """ |
| 85 A subset of self.attributeNames() which are the names of attributes modified/created by update() in order |
| 86 to do its work. |
| 87 """ |
| 88 raise AbstractFunction() |
| 89 |
| 90 def useOutputAttributes(self): |
| 91 """ |
| 92 A subset of self.attributeNames() which are the names of attributes modified/created by use() in order |
| 93 to do its work. |
| 94 """ |
| 95 raise AbstractFunction() |
| 96 |
| 97 |
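
A standalone sketch (hypothetical class and attribute names, not taken from the pylearn sources) of how a concrete subclass might answer these four hooks:

    # Hypothetical example: a learner holding weights 'W', a bias 'b'
    # and a scalar 'regularization_term'.
    class SketchedLinearLearner:
        def attributeNames(self):
            return ['W', 'b', 'regularization_term']
        def updateInputAttributes(self):
            # attributes update() must read before doing its work
            return ['W', 'b']
        def updateOutputAttributes(self):
            # attributes update() creates or modifies
            return ['W', 'b', 'regularization_term']
        def useInputAttributes(self):
            # attributes use() needs to compute its output fields
            return ['W', 'b']
        def useOutputAttributes(self):
            # in this sketch use() modifies nothing
            return []
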
69 class TLearner(Learner): | 98 class TLearner(Learner): |
70 """ | 99 """ |
71 TLearner is a virtual class of Learners that attempts to factor out of the definition | 100 TLearner is a virtual class of Learners that attempts to factor out of the definition |
72 of a learner the steps that are common to many implementations of learning algorithms, | 101 of a learner the steps that are common to many implementations of learning algorithms, |
73 so as to leave only 'the equations' to define in particular sub-classes, using Theano. | 102 so as to leave only 'the equations' to define in particular sub-classes, using Theano. |
101 dependent on Theano | 130 dependent on Theano |
102 """ | 131 """ |
103 | 132 |
104 def __init__(self): | 133 def __init__(self): |
105 Learner.__init__(self) | 134 Learner.__init__(self) |
| 135 |
| 136 def defaultOutputFields(self, input_fields): |
| 137 """ |
| 138 Return a default list of output field names (to put in the output dataset). |
| 139 This will be used when None are provided (as output_fields) by the caller of the 'use' method. |
| 140 This may involve looking at the input_fields (names) available in the |
| 141 input_dataset. |
| 142 """ |
| 143 raise AbstractFunction() |
| 144 |
| 145 def allocate(self, minibatch): |
| 146 """ |
| 147 This function is called at the beginning of each updateMinibatch |
| 148 and should be used to check that all required attributes have been |
| 149 allocated and initialized (usually this function calls forget() |
| 150 when it has to do an initialization). |
| 151 """ |
| 152 raise AbstractFunction() |
106 | 153 |
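
allocate() is meant for lazy (re)initialization at the start of each updateMinibatch, and defaultOutputFields() chooses output field names when the caller of use() passes None. A standalone sketch of both, assuming numpy and hypothetical 'input'/'target' fields stored as 2-D arrays (n_examples x n_features):

    import numpy

    class AllocationSketch:
        def __init__(self):
            self._input_size = None
        def forget(self):
            # (re)initialize the parameters for the current input width
            self.W = numpy.zeros((self._input_size, 1))
            self.b = numpy.zeros(1)
        def allocate(self, minibatch):
            # called at the beginning of each updateMinibatch; only
            # re-initializes when the input width changes
            input_size = minibatch['input'].shape[1]
            if input_size != self._input_size:
                self._input_size = input_size
                self.forget()
        def defaultOutputFields(self, input_fields):
            # always predict 'output'; report an error field only when a
            # target is available among the input fields
            fields = ['output']
            if 'target' in input_fields:
                fields.append('squared_error')
            return fields
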
107 def _minibatchwise_use_functions(self, input_fields, output_fields, stats_collector): | 154 def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector): |
108 """ | 155 """ |
109 Private helper function called by the generic TLearner.use. It returns a function | 156 Private helper function called by the generic TLearner.use. It returns a function |
110 that can map the given input fields to the given output fields (along with the | 157 that can map the given input fields to the given output fields (along with the |
111 attributes that the stats collector needs for its computation). | 158 attributes that the stats collector needs for its computation). The returned |
| 159 function also automatically feeds in the current values of self.useInputAttributes() and |
| 160 stores back the attributes named by self.useOutputAttributes(). |
112 """ | 161 """ |
113 if not output_fields: | 162 if not output_fields: |
114 output_fields = self.defaultOutputFields(input_fields) | 163 output_fields = self.defaultOutputFields(input_fields) |
115 if stats_collector: | 164 if stats_collector: |
116 stats_collector_inputs = stats_collector.inputUpdateAttributes() | 165 stats_collector_inputs = stats_collector.input2UpdateAttributes() |
117 for attribute in stats_collector_inputs: | 166 for attribute in stats_collector_inputs: |
118 if attribute not in input_fields: | 167 if attribute not in input_fields: |
119 output_fields.append(attribute) | 168 output_fields.append(attribute) |
120 key = (input_fields,output_fields) | 169 key = (input_fields,output_fields) |
121 if key not in self.use_functions_dictionary: | 170 if key not in self.use_functions_dictionary: |
122 self.use_functions_dictionary[key]=Function(self._names2attributes(input_fields), | 171 use_input_attributes = self.useInputAttributes() |
123 self._names2attributes(output_fields)) | 172 use_output_attributes = self.useOutputAttributes() |
| 173 complete_f = Function(self.names2OpResults(input_fields+use_input_attributes), |
| 174 self.names2OpResults(output_fields+use_output_attributes)) |
| 175 def f(*input_field_values): |
| 176 input_attribute_values = self.names2attributes(use_input_attributes) |
| 177 results = complete_f(*(list(input_field_values) + input_attribute_values)) |
| 178 output_field_values = results[0:len(output_fields)] |
| 179 output_attribute_values = results[len(output_fields):len(results)] |
| 180 if use_output_attributes: |
| 181 self.setAttributes(use_output_attributes,output_attribute_values) |
| 182 return output_field_values |
| 183 self.use_functions_dictionary[key]=f |
124 return self.use_functions_dictionary[key] | 184 return self.use_functions_dictionary[key] |
125 | 185 |
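
The new minibatchwise_use_functions compiles a function over the input fields plus useInputAttributes() and the output fields plus useOutputAttributes(), then wraps it in a closure that feeds in the current attribute values and writes the produced attribute values back onto the learner, so the caller only ever sees fields. A standalone sketch of that wrapper pattern, with plain callables and getattr/setattr standing in for Function, names2attributes and setAttributes:

    def make_minibatchwise_use_function(learner, complete_f, n_output_fields,
                                        use_input_attributes, use_output_attributes):
        def f(*input_field_values):
            # fetch the current values of the attributes the computation needs
            input_attribute_values = [getattr(learner, name)
                                      for name in use_input_attributes]
            results = complete_f(*(list(input_field_values) + input_attribute_values))
            # the first results are the requested fields, the rest are attributes
            output_field_values = results[:n_output_fields]
            output_attribute_values = results[n_output_fields:]
            for name, value in zip(use_output_attributes, output_attribute_values):
                setattr(learner, name, value)
            return output_field_values
        return f

Here complete_f stands for the compiled Theano Function; any callable returning a flat sequence of results works for the sketch.
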
126 def attributes(self,return_copy=False): | 186 def attributes(self,return_copy=False): |
127 """ | 187 """ |
128 Return a list with the values of the learner's attributes (or optionally, a deep copy). | 188 Return a list with the values of the learner's attributes (or optionally, a deep copy). |
129 """ | 189 """ |
130 return self.names2attributes(self.attributeNames()) | 190 return self.names2attributes(self.attributeNames(),return_copy) |
131 | 191 |
132 def _names2attributes(self,names,return_Result=False, return_copy=False): | 192 def names2attributes(self,names,return_copy=False): |
133 """ | 193 """ |
134 Private helper function that maps a list of attribute names to a list | 194 Private helper function that maps a list of attribute names to a list |
135 of (optionally copies) values or of the Result objects that own these values. | 195 of (optionally copied) attribute values. |
136 """ | 196 """ |
137 if return_Result: | 197 if return_copy: |
138 if return_copy: | 198 return [copy.deepcopy(self.__getattr__(name).data) for name in names] |
139 return [copy.deepcopy(self.__getattr__(name)) for name in names] | |
140 else: | |
141 return [self.__getattr__(name) for name in names] | |
142 else: | 199 else: |
143 if return_copy: | 200 return [self.__getattr__(name).data for name in names] |
144 return [copy.deepcopy(self.__getattr__(name).data) for name in names] | 201 |
145 else: | 202 def names2OpResults(self,names): |
146 return [self.__getattr__(name).data for name in names] | 203 """ |
| 204 Private helper function that maps a list of attribute names to a list |
| 205 of corresponding Op Results (with the same name but with a '_' prefix). |
| 206 """ |
| 207 return [self.__getattr__('_'+name) for name in names] |
147 | 208 |
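
names2attributes and names2OpResults rely on a naming convention: the value-holding object for attribute 'W' lives in self.W (numeric content in .data) and the corresponding symbolic Op result in self._W. A small standalone illustration of that convention; the Boxed class is only a stand-in for a value-holding Result, not a real API:

    import copy

    class Boxed(object):
        # stand-in for a value-holding Result: the number lives in .data
        def __init__(self, data):
            self.data = data

    class NamingSketch:
        def __init__(self):
            self.W = Boxed([[0.0, 0.0]])   # current value of attribute 'W'
            self._W = object()             # stand-in for the symbolic variable _W
        def names2attributes(self, names, return_copy=False):
            values = [getattr(self, name).data for name in names]
            return copy.deepcopy(values) if return_copy else values
        def names2OpResults(self, names):
            # return the symbolic objects themselves, not their .data
            return [getattr(self, '_' + name) for name in names]
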
148 def use(self,input_dataset,output_fieldnames=None,output_attributes=[], | 209 def use(self,input_dataset,output_fieldnames=None,output_attributes=[], |
149 test_stats_collector=None,copy_inputs=True): | 210 test_stats_collector=None,copy_inputs=True, put_stats_in_output_dataset=True): |
150 """ | 211 """ |
151 The learner tries to compute in the output dataset the output fields specified | 212 The learner tries to compute in the output dataset the output fields specified |
152 | 213 |
153 @todo check if some of the learner attributes are actually SPECIFIED | 214 @todo check if some of the learner attributes are actually SPECIFIED |
154 as attributes of the input_dataset, and if so use their values instead | 215 as attributes of the input_dataset, and if so use their values instead |
162 If output_attributes is None then all of the attributes in self.AttributeNames() | 223 If output_attributes is None then all of the attributes in self.AttributeNames() |
163 are copied in the output dataset, but if it is [] (the default), then none are copied. | 224 are copied in the output dataset, but if it is [] (the default), then none are copied. |
164 If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) | 225 If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) |
165 are also copied into the output dataset attributes. | 226 are also copied into the output dataset attributes. |
166 """ | 227 """ |
167 minibatchwise_use_function = _minibatchwise_use_functions(input_dataset.fieldNames(), | 228 minibatchwise_use_function = self.minibatchwise_use_functions(input_dataset.fieldNames(), |
168 output_fieldnames, | 229 output_fieldnames, |
169 test_stats_collector) | 230 test_stats_collector) |
170 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, | 231 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, |
171 minibatchwise_use_function, | 232 minibatchwise_use_function, |
172 True,DataSet.numpy_vstack, | 233 True,DataSet.numpy_vstack, |
177 output_dataset = input_dataset | output_dataset | 238 output_dataset = input_dataset | output_dataset |
178 # copy the wanted attributes in the dataset | 239 # copy the wanted attributes in the dataset |
179 if output_attributes is None: | 240 if output_attributes is None: |
180 output_attributes = self.attributeNames() | 241 output_attributes = self.attributeNames() |
181 if output_attributes: | 242 if output_attributes: |
182 assert set(output_attributes) <= set(self.attributeNames()) | 243 assert set(output_attributes) <= set(self.attributeNames()) |
183 output_dataset.setAttributes(output_attributes, | 244 output_dataset.setAttributes(output_attributes, |
184 self._names2attributes(output_attributes,return_copy=True)) | 245 self.names2attributes(output_attributes,return_copy=True)) |
185 if test_stats_collector: | 246 if test_stats_collector: |
186 test_stats_collector.update(output_dataset) | 247 test_stats_collector.update(output_dataset) |
187 output_dataset.setAttributes(test_stats_collector.attributeNames(), | 248 if put_stats_in_output_dataset: |
188 test_stats_collector.attributes()) | 249 output_dataset.setAttributes(test_stats_collector.attributeNames(), |
| 250 test_stats_collector.attributes()) |
189 return output_dataset | 251 return output_dataset |
190 | 252 |
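
use() delegates the per-minibatch work to ApplyFunctionDataSet, which maps the minibatchwise function over the input dataset and stacks the per-minibatch outputs (DataSet.numpy_vstack above). A standalone sketch of that idea, assuming numpy and nothing from pylearn:

    import numpy

    def apply_minibatchwise(minibatch_field_values_list, f, n_output_fields):
        # minibatch_field_values_list: one tuple of field arrays per minibatch
        columns = [[] for _ in range(n_output_fields)]
        for field_values in minibatch_field_values_list:
            for column, output in zip(columns, f(*field_values)):
                column.append(output)
        # vertical stacking plays the role of DataSet.numpy_vstack in the diff
        return [numpy.vstack(column) for column in columns]
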
191 | 253 |
192 class OneShotTLearner(TLearner): | 254 class OneShotTLearner(TLearner): |
193 """ | 255 """ |
194 This adds to TLearner a | 256 This adds to TLearner a |
195 - update_start(), update_end(), update_minibatch(minibatch), end_epoch(): | 257 - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch(): |
196 functions executed at the beginning, the end, in the middle | 258 functions executed at the beginning, the end, in the middle |
197 (for each minibatch) of the update method, and at the end | 259 (for each minibatch) of the update method, and at the end |
198 of each epoch. This model only | 260 of each epoch. This model only |
199 works for 'online' or one-shot learning that requires | 261 works for 'online' or one-shot learning that requires |
200 going only once through the training data. For more complicated | 262 going only once through the training data. For more complicated |
202 or a learning-algorithm specific update method should be defined. | 264 or a learning-algorithm specific update method should be defined. |
203 """ | 265 """ |
204 | 266 |
205 def __init__(self): | 267 def __init__(self): |
206 TLearner.__init__(self) | 268 TLearner.__init__(self) |
| 269 self.update_minibatch_function = \ |
| 270 Function(self.names2OpResults(self.updateMinibatchOutputAttributes()+ |
| 271 self.updateMinibatchInputFields()), |
| 272 self.names2OpResults(self.updateMinibatchOutputAttributes())) |
| 273 self.update_end_function = Function(self.names2OpResults(self.updateEndInputAttributes()), |
| 274 self.names2OpResults(self.updateEndOutputAttributes())) |
| 275 |
| 276 def updateMinibatchInputFields(self): |
| 277 raise AbstractFunction() |
| 278 |
| 279 def updateMinibatchInputAttributes(self): |
| 280 raise AbstractFunction() |
| 281 |
| 282 def updateMinibatchOutputAttributes(self): |
| 283 raise AbstractFunction() |
| 284 |
| 285 def updateEndInputAttributes(self): |
| 286 raise AbstractFunction() |
| 287 |
| 288 def updateEndOutputAttributes(self): |
| 289 raise AbstractFunction() |
| 290 |
| 291 def updateStart(self,training_set): pass |
| 292 |
| 293 def updateEnd(self): |
| 294 self.setAttributes(self.updateEndOutputAttributes(), |
| 295 self.update_end_function |
| 296 (*self.names2attributes(self.updateEndInputAttributes()))) |
207 | 297 |
208 def update_start(self): pass | 298 def updateMinibatch(self,minibatch): |
209 def update_end(self): pass | 299 # make sure all required fields are allocated and initialized |
210 def update_minibatch(self,minibatch): | 300 self.allocate(minibatch) |
211 raise AbstractFunction() | 301 self.setAttributes(self.updateMinibatchOutputAttributes(), |
| 302 self.update_minibatch_function(*(self.names2attributes(self.updateMinibatchInputAttributes())) |
| 303 + minibatch(self.updateMinibatchInputFields()))) |
| 304 |
| 305 def isLastEpoch(self): |
| 306 """ |
| 307 This method is called at the end of each epoch (cycling over the training set). |
| 308 It returns a boolean to indicate if this is the last epoch. |
| 309 By default just do one epoch. |
| 310 """ |
| 311 return True |
212 | 312 |
213 def update(self,training_set,train_stats_collector=None): | 313 def update(self,training_set,train_stats_collector=None): |
214 """ | 314 """ |
215 @todo check if some of the learner attributes are actually SPECIFIED | 315 @todo check if some of the learner attributes are actually SPECIFIED |
216 as attributes of the training_set. | 316 as attributes of the training_set. |
217 """ | 317 """ |
218 self.update_start() | 318 self.updateStart(training_set) |
219 stop=False | 319 stop=False |
220 while not stop: | 320 while not stop: |
221 if train_stats_collector: | 321 if train_stats_collector: |
222 train_stats_collector.forget() # restart stats collection at the beginning of each epoch | 322 train_stats_collector.forget() # restart stats collection at the beginning of each epoch |
223 for minibatch in training_set.minibatches(self.training_set_input_fields, | 323 for minibatch in training_set.minibatches(self.training_set_input_fields, |
225 self.update_minibatch(minibatch) | 325 self.updateMinibatch(minibatch) |
226 if train_stats_collector: | 326 if train_stats_collector: |
227 minibatch_set = minibatch.examples() | 327 minibatch_set = minibatch.examples() |
228 minibatch_set.setAttributes(self.attributeNames(),self.attributes()) | 328 minibatch_set.setAttributes(self.attributeNames(),self.attributes()) |
229 train_stats_collector.update(minibatch_set) | 329 train_stats_collector.update(minibatch_set) |
230 stop = self.end_epoch() | 330 stop = self.isLastEpoch() |
231 self.update_end() | 331 self.updateEnd() |
232 return self.use | 332 return self.use |
233 | 333 |
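
Putting the OneShotTLearner pieces together: update() loops over epochs, calls updateMinibatch() on every minibatch, optionally feeds a train_stats_collector, stops when isLastEpoch() returns True, and finally returns self.use. A self-contained toy walk-through of that protocol (no Theano; a hypothetical averaging learner with plain dicts as minibatches):

    class ToyOneShotLearner:
        # a degenerate learner that just averages the 'input' field it sees
        def __init__(self):
            self.sum = 0.0
            self.count = 0
        def updateStart(self, training_set): pass
        def updateMinibatch(self, minibatch):
            self.sum += sum(minibatch['input'])
            self.count += len(minibatch['input'])
        def isLastEpoch(self):
            return True            # one pass over the data, as in the default above
        def updateEnd(self):
            self.mean = self.sum / self.count
        def use(self, dataset):
            return [self.mean for _ in dataset]
        def update(self, training_set):
            self.updateStart(training_set)
            stop = False
            while not stop:
                for minibatch in training_set:
                    self.updateMinibatch(minibatch)
                stop = self.isLastEpoch()
            self.updateEnd()
            return self.use

    # usage: two minibatches of the hypothetical 'input' field
    learner = ToyOneShotLearner()
    use = learner.update([{'input': [1.0, 2.0]}, {'input': [3.0]}])
    print(use([{'input': [0.0]}]))   # -> [2.0]
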