comparison learner.py @ 110:8fa1ef2411a0

Worked on OneShotTLearner and implementation of LinearRegression
author bengioy@bengiomac.local
date Tue, 06 May 2008 22:24:55 -0400
parents d97f6fe6bdf9
children 88257dfedf8c
comparison
equal deleted inserted replaced
109:d97f6fe6bdf9 110:8fa1ef2411a0
1 1
2 from dataset import * 2 from dataset import *
3 3
4 class Learner(object): 4 class Learner(AttributesHolder):
5 """Base class for learning algorithms, provides an interface 5 """Base class for learning algorithms, provides an interface
6 that allows various algorithms to be applicable to generic learning 6 that allows various algorithms to be applicable to generic learning
7 algorithms. 7 algorithms.
8 8
9 A Learner can be seen as a learning algorithm, a function that when 9 A Learner can be seen as a learning algorithm, a function that when
64 64
65 @todo By default, attributeNames looks for all dictionary entries whose name does not start with _. 65 @todo By default, attributeNames looks for all dictionary entries whose name does not start with _.
66 """ 66 """
67 return [] 67 return []
68 68
69 def updateInputAttributes(self):
70 """
71 A subset of self.attributeNames() which are the names of attributes needed by update() in order
72 to do its work.
73 """
74 raise AbstractFunction()
75
76 def useInputAttributes(self):
77 """
78 A subset of self.attributeNames() which are the names of attributes needed by use() in order
79 to do its work.
80 """
81 raise AbstractFunction()
82
83 def updateOutputAttributes(self):
84 """
85 A subset of self.attributeNames() which are the names of attributes modified/created by update() in order
86 to do its work.
87 """
88 raise AbstractFunction()
89
90 def useOutputAttributes(self):
91 """
92 A subset of self.attributeNames() which are the names of attributes modified/created by use() in order
93 to do its work.
94 """
95 raise AbstractFunction()
96
97
69 class TLearner(Learner): 98 class TLearner(Learner):
70 """ 99 """
71 TLearner is a virtual class of Learners that attempts to factor out of the definition 100 TLearner is a virtual class of Learners that attempts to factor out of the definition
72 of a learner the steps that are common to many implementations of learning algorithms, 101 of a learner the steps that are common to many implementations of learning algorithms,
73 so as to leave only 'the equations' to define in particular sub-classes, using Theano. 102 so as to leave only 'the equations' to define in particular sub-classes, using Theano.
101 dependant de Theano 130 dependant de Theano
102 """ 131 """
103 132
104 def __init__(self): 133 def __init__(self):
105 Learner.__init__(self) 134 Learner.__init__(self)
135
136 def defaultOutputFields(self, input_fields):
137 """
138 Return a default list of output field names (to put in the output dataset).
139 This will be used when None are provided (as output_fields) by the caller of the 'use' method.
140 This may involve looking at the input_fields (names) available in the
141 input_dataset.
142 """
143 raise AbstractFunction()
144
145 def allocate(self, minibatch):
146 """
147 This function is called at the beginning of each updateMinibatch
148 and should be used to check that all required attributes have been
149 allocated and initialized (usually this function calls forget()
150 when it has to do an initialization).
151 """
152 raise AbstractFunction()
106 153
107 def _minibatchwise_use_functions(self, input_fields, output_fields, stats_collector): 154 def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
108 """ 155 """
109 Private helper function called by the generic TLearner.use. It returns a function 156 Private helper function called by the generic TLearner.use. It returns a function
110 that can map the given input fields to the given output fields (along with the 157 that can map the given input fields to the given output fields (along with the
111 attributes that the stats collector needs for its computation). 158 attributes that the stats collector needs for its computation). The returned
159 function also automatically makes use of the self.useInputAttributes() and
160 sets the self.useOutputAttributes().
112 """ 161 """
113 if not output_fields: 162 if not output_fields:
114 output_fields = self.defaultOutputFields(input_fields) 163 output_fields = self.defaultOutputFields(input_fields)
115 if stats_collector: 164 if stats_collector:
116 stats_collector_inputs = stats_collector.inputUpdateAttributes() 165 stats_collector_inputs = stats_collector.input2UpdateAttributes()
117 for attribute in stats_collector_inputs: 166 for attribute in stats_collector_inputs:
118 if attribute not in input_fields: 167 if attribute not in input_fields:
119 output_fields.append(attribute) 168 output_fields.append(attribute)
120 key = (input_fields,output_fields) 169 key = (input_fields,output_fields)
121 if key not in self.use_functions_dictionary: 170 if key not in self.use_functions_dictionary:
122 self.use_functions_dictionary[key]=Function(self._names2attributes(input_fields), 171 use_input_attributes = self.useInputAttributes()
123 self._names2attributes(output_fields)) 172 use_output_attributes = self.useOutputAttributes()
173 complete_f = Function(self.names2OpResults(input_fields+use_input_attributes),
174 self.names2OpResults(output_fields+use_output_attributes))
175 def f(*input_field_values):
176 input_attribute_values = self.names2attributes(use_input_attributes)
177 results = complete_f(*(input_field_values + input_attribute_values))
178 output_field_values = results[0:len(output_fields)]
179 output_attribute_values = results[len(output_fields):len(results)]
180 if use_output_attributes:
181 self.setAttributes(use_output_attributes,output_attribute_values)
182 return output_field_values
183 self.use_functions_dictionary[key]=f
124 return self.use_functions_dictionary[key] 184 return self.use_functions_dictionary[key]
125 185
126 def attributes(self,return_copy=False): 186 def attributes(self,return_copy=False):
127 """ 187 """
128 Return a list with the values of the learner's attributes (or optionally, a deep copy). 188 Return a list with the values of the learner's attributes (or optionally, a deep copy).
129 """ 189 """
130 return self.names2attributes(self.attributeNames()) 190 return self.names2attributes(self.attributeNames(),return_copy)
131 191
132 def _names2attributes(self,names,return_Result=False, return_copy=False): 192 def names2attributes(self,names,return_copy=False):
133 """ 193 """
134 Private helper function that maps a list of attribute names to a list 194 Private helper function that maps a list of attribute names to a list
135 of (optionally copied) values or of the Result objects that own these values. 195 of (optionally copied) values of attributes.
136 """ 196 """
137 if return_Result: 197 if return_copy:
138 if return_copy: 198 return [copy.deepcopy(self.__getattr__(name).data) for name in names]
139 return [copy.deepcopy(self.__getattr__(name)) for name in names]
140 else:
141 return [self.__getattr__(name) for name in names]
142 else: 199 else:
143 if return_copy: 200 return [self.__getattr__(name).data for name in names]
144 return [copy.deepcopy(self.__getattr__(name).data) for name in names] 201
145 else: 202 def names2OpResults(self,names):
146 return [self.__getattr__(name).data for name in names] 203 """
204 Private helper function that maps a list of attribute names to a list
205 of corresponding Op Results (with the same name but with a '_' prefix).
206 """
207 return [self.__getattr__('_'+name).data for name in names]
147 208
148 def use(self,input_dataset,output_fieldnames=None,output_attributes=[], 209 def use(self,input_dataset,output_fieldnames=None,output_attributes=[],
149 test_stats_collector=None,copy_inputs=True): 210 test_stats_collector=None,copy_inputs=True, put_stats_in_output_dataset=True):
150 """ 211 """
151 The learner tries to compute in the output dataset the output fields specified 212 The learner tries to compute in the output dataset the output fields specified
152 213
153 @todo check if some of the learner attributes are actually SPECIFIED 214 @todo check if some of the learner attributes are actually SPECIFIED
154 as attributes of the input_dataset, and if so use their values instead 215 as attributes of the input_dataset, and if so use their values instead
162 If output_attributes is None then all of the attributes in self.AttributeNames() 223 If output_attributes is None then all of the attributes in self.AttributeNames()
163 are copied in the output dataset, but if it is [] (the default), then none are copied. 224 are copied in the output dataset, but if it is [] (the default), then none are copied.
164 If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) 225 If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames())
165 are also copied into the output dataset attributes. 226 are also copied into the output dataset attributes.
166 """ 227 """
167 minibatchwise_use_function = _minibatchwise_use_functions(input_dataset.fieldNames(), 228 minibatchwise_use_function = minibatchwise_use_functions(input_dataset.fieldNames(),
168 output_fieldnames, 229 output_fieldnames,
169 test_stats_collector) 230 test_stats_collector)
170 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, 231 virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
171 minibatchwise_use_function, 232 minibatchwise_use_function,
172 True,DataSet.numpy_vstack, 233 True,DataSet.numpy_vstack,
177 output_dataset = input_dataset | output_dataset 238 output_dataset = input_dataset | output_dataset
178 # copy the wanted attributes in the dataset 239 # copy the wanted attributes in the dataset
179 if output_attributes is None: 240 if output_attributes is None:
180 output_attributes = self.attributeNames() 241 output_attributes = self.attributeNames()
181 if output_attributes: 242 if output_attributes:
182 assert set(output_attributes) <= set(self.attributeNames()) 243 assert set(attribute_names) <= set(self.attributeNames())
183 output_dataset.setAttributes(output_attributes, 244 output_dataset.setAttributes(output_attributes,
184 self._names2attributes(output_attributes,return_copy=True)) 245 self.names2attributes(output_attributes,return_copy=True))
185 if test_stats_collector: 246 if test_stats_collector:
186 test_stats_collector.update(output_dataset) 247 test_stats_collector.update(output_dataset)
187 output_dataset.setAttributes(test_stats_collector.attributeNames(), 248 if put_stats_in_output_dataset:
188 test_stats_collector.attributes()) 249 output_dataset.setAttributes(test_stats_collector.attributeNames(),
250 test_stats_collector.attributes())
189 return output_dataset 251 return output_dataset
190 252
191 253
192 class OneShotTLearner(TLearner): 254 class OneShotTLearner(TLearner):
193 """ 255 """
194 This adds to TLearner a 256 This adds to TLearner a
195 - update_start(), update_end(), update_minibatch(minibatch), end_epoch(): 257 - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch():
196 functions executed at the beginning, the end, in the middle 258 functions executed at the beginning, the end, in the middle
197 (for each minibatch) of the update method, and at the end 259 (for each minibatch) of the update method, and at the end
198 of each epoch. This model only 260 of each epoch. This model only
199 works for 'online' or one-shot learning that requires 261 works for 'online' or one-shot learning that requires
200 going only once through the training data. For more complicated 262 going only once through the training data. For more complicated
202 or a learning-algorithm specific update method should be defined. 264 or a learning-algorithm specific update method should be defined.
203 """ 265 """
204 266
205 def __init__(self): 267 def __init__(self):
206 TLearner.__init__(self) 268 TLearner.__init__(self)
269 self.update_minibatch_function =
270 Function(self.names2OpResults(self.updateMinibatchOutputAttributes()+
271 self.updateMinibatchInputFields()),
272 self.names2OpResults(self.updateMinibatchOutputAttributes()))
273 self.update_end_function = Function(self.names2OpResults(self.updateEndInputAttributes()),
274 self.names2OpResults(self.updateEndOutputAttributes()))
275
276 def updateMinibatchInputFields(self):
277 raise AbstractFunction()
278
279 def updateMinibatchInputAttributes(self):
280 raise AbstractFunction()
281
282 def updateMinibatchOutputAttributes(self):
283 raise AbstractFunction()
284
285 def updateEndInputAttributes(self):
286 raise AbstractFunction()
287
288 def updateEndOutputAttributes(self):
289 raise AbstractFunction()
290
291 def updateStart(self): pass
292
293 def updateEnd(self):
294 self.setAttributes(self.updateEndOutputAttributes(),
295 self.update_end_function
296 (self.names2attributes(self.updateEndInputAttributes())))
207 297
208 def update_start(self): pass 298 def updateMinibatch(self,minibatch):
209 def update_end(self): pass 299 # make sure all required fields are allocated and initialized
210 def update_minibatch(self,minibatch): 300 self.allocate(minibatch)
211 raise AbstractFunction() 301 self.setAttributes(self.updateMinibatchOutputAttributes(),
302 self.update_minibatch_function(*(self.names2attributes(self.updateMinibatchInputAttributes()))
303 + minibatch(self.updateMinibatchInputFields())))
304
305 def isLastEpoch(self):
306 """
307 This method is called at the end of each epoch (cycling over the training set).
308 It returns a boolean to indicate if this is the last epoch.
309 By default just do one epoch.
310 """
311 return True
212 312
213 def update(self,training_set,train_stats_collector=None): 313 def update(self,training_set,train_stats_collector=None):
214 """ 314 """
215 @todo check if some of the learner attributes are actually SPECIFIED 315 @todo check if some of the learner attributes are actually SPECIFIED
216 as attributes of the training_set. 316 as attributes of the training_set.
217 """ 317 """
218 self.update_start() 318 self.updateStart(training_set)
219 stop=False 319 stop=False
220 while not stop: 320 while not stop:
221 if train_stats_collector: 321 if train_stats_collector:
222 train_stats_collector.forget() # restart stats collection at the beginning of each epoch 322 train_stats_collector.forget() # restart stats collection at the beginning of each epoch
223 for minibatch in training_set.minibatches(self.training_set_input_fields, 323 for minibatch in training_set.minibatches(self.training_set_input_fields,
225 self.update_minibatch(minibatch) 325 self.update_minibatch(minibatch)
226 if train_stats_collector: 326 if train_stats_collector:
227 minibatch_set = minibatch.examples() 327 minibatch_set = minibatch.examples()
228 minibatch_set.setAttributes(self.attributeNames(),self.attributes()) 328 minibatch_set.setAttributes(self.attributeNames(),self.attributes())
229 train_stats_collector.update(minibatch_set) 329 train_stats_collector.update(minibatch_set)
230 stop = self.end_epoch() 330 stop = self.isLastEpoch()
231 self.update_end() 331 self.updateEnd()
232 return self.use 332 return self.use
233 333