Mercurial > pylearn
comparison dataset.py @ 29:46c5c90019c2
Changed apply_function so that it propagates methods of the source.
author | bengioy@grenat.iro.umontreal.ca |
---|---|
date | Fri, 11 Apr 2008 15:46:18 -0400 |
parents | 541a273bc89f |
children | 438440ba0627 |
comparison
equal
deleted
inserted
replaced
28:541a273bc89f | 29:46c5c90019c2 |
---|---|
148 Return true if the given field name (or field names, if multiple arguments are | 148 Return true if the given field name (or field names, if multiple arguments are |
149 given) is recognized by the DataSet (i.e. can be used as a field name in one | 149 given) is recognized by the DataSet (i.e. can be used as a field name in one |
150 of the iterators). | 150 of the iterators). |
151 """ | 151 """ |
152 raise AbstractFunction() | 152 raise AbstractFunction() |
153 | |
153 | 154 |
154 def merge_fields(self,*specifications): | 155 def merge_fields(self,*specifications): |
155 """ | 156 """ |
156 Return a new dataset that maps old fields (of self) to new fields (of the returned | 157 Return a new dataset that maps old fields (of self) to new fields (of the returned |
157 dataset). The minimal syntax that should be supported is the following: | 158 dataset). The minimal syntax that should be supported is the following: |
180 """ | 181 """ |
181 raise AbstractFunction() | 182 raise AbstractFunction() |
182 | 183 |
183 def rename(self,rename_dict): | 184 def rename(self,rename_dict): |
184 """ | 185 """ |
185 Return a new dataset that renames fields, using a dictionnary that maps old field | 186 Changes a dataset into one that renames fields, using a dictionnary that maps old field |
186 names to new field names. The only fields visible by the returned dataset are those | 187 names to new field names. The only fields visible by the returned dataset are those |
187 whose names are keys of the rename_dict. | 188 whose names are keys of the rename_dict. |
188 """ | 189 """ |
189 self_class = self.__class__ | 190 self_class = self.__class__ |
190 class SelfRenamingDataSet(RenamingDataSet,self_class): | 191 class SelfRenamingDataSet(RenamingDataSet,self_class): |
192 self.__class__ = SelfRenamingDataSet | 193 self.__class__ = SelfRenamingDataSet |
193 # set the rename_dict and src fields | 194 # set the rename_dict and src fields |
194 SelfRenamingDataSet.__init__(self,self,rename_dict) | 195 SelfRenamingDataSet.__init__(self,self,rename_dict) |
195 return self | 196 return self |
196 | 197 |
197 def applyFunction(self,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): | 198 def apply_function(self,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): |
198 """ | 199 """ |
199 Return a dataset that contains as fields the results of applying | 200 Changes a dataset into one that contains as fields the results of applying |
200 the given function (example-wise) to the specified input_fields. The | 201 the given function (example-wise) to the specified input_fields. The |
201 function should return a sequence whose elements will be stored in | 202 function should return a sequence whose elements will be stored in |
202 fields whose names are given in the output_fields list. If copy_inputs | 203 fields whose names are given in the output_fields list. If copy_inputs |
203 is True then the resulting dataset will also contain the fields of self. | 204 is True then the resulting dataset will also contain the fields of self. |
204 If accept_minibatches, then the function may be called | 205 If accept_minibatches, then the function may be called |
207 of the resulting dataset are requested. If cache is True, then | 208 of the resulting dataset are requested. If cache is True, then |
208 once the output fields for some examples have been computed, then | 209 once the output fields for some examples have been computed, then |
209 are cached (to avoid recomputation if the same examples are again | 210 are cached (to avoid recomputation if the same examples are again |
210 requested). | 211 requested). |
211 """ | 212 """ |
212 return ApplyFunctionDataSet(function, input_fields, output_fields, copy_inputs, accept_minibatches, cache) | 213 self_class = self.__class__ |
214 class SelfApplyFunctionDataSet(ApplyFunctionDataSet,self_class): | |
215 pass | |
216 self.__class__ = SelfApplyFunctionDataSet | |
217 # set the required additional fields | |
218 ApplyFunctionDataSet.__init__(self,self,function, input_fields, output_fields, copy_inputs, accept_minibatches, cache) | |
219 return self | |
213 | 220 |
214 | 221 |
215 class FiniteLengthDataSet(DataSet): | 222 class FiniteLengthDataSet(DataSet): |
216 """ | 223 """ |
217 Virtual interface for datasets that have a finite length (number of examples), | 224 Virtual interface for datasets that have a finite length (number of examples), |
221 DataSet.__init__(self) | 228 DataSet.__init__(self) |
222 | 229 |
223 def __len__(self): | 230 def __len__(self): |
224 """len(dataset) returns the number of examples in the dataset.""" | 231 """len(dataset) returns the number of examples in the dataset.""" |
225 raise AbstractFunction() | 232 raise AbstractFunction() |
226 | 233 |
234 def __call__(self,fieldname_or_fieldnames): | |
235 """ | |
236 Extract one or more fields. This may be an expensive operation when the | |
237 dataset is large. It is not the recommanded way to access individual values | |
238 (use the iterators instead). If the argument is a string fieldname, then the result | |
239 is a sequence (iterable object) of values for that field, for the whole dataset. If the | |
240 argument is a list of field names, then the result is a 'batch', i.e., an Example with keys | |
241 corresponding to the given field names and values being iterable objects over the | |
242 individual example values. | |
243 """ | |
244 if type(fieldname_or_fieldnames) is string: | |
245 minibatch = self.minibatches([fieldname_or_fieldnames],len(self)).next() | |
246 return minibatch[fieldname_or_fieldnames] | |
247 return self.minibatches(fieldname_or_fieldnames,len(self)).next() | |
227 | 248 |
228 class SliceableDataSet(DataSet): | 249 class SliceableDataSet(DataSet): |
229 """ | 250 """ |
230 Virtual interface, a subclass of DataSet for datasets which are sliceable | 251 Virtual interface, a subclass of DataSet for datasets which are sliceable |
231 and whose individual elements can be accessed, generally respecting the | 252 and whose individual elements can be accessed, generally respecting the |
470 for the given minibatch_size (possibly missing some near the end). | 491 for the given minibatch_size (possibly missing some near the end). |
471 """ | 492 """ |
472 # substitute the defaults: | 493 # substitute the defaults: |
473 if n_batches is None: n_batches = len(self) / minibatch_size | 494 if n_batches is None: n_batches = len(self) / minibatch_size |
474 return ArrayDataSet.Iterator(self, fieldnames, minibatch_size, n_batches) | 495 return ArrayDataSet.Iterator(self, fieldnames, minibatch_size, n_batches) |
475 | |
476 def __getattr__(self,fieldname): | |
477 """ | |
478 Return a numpy array with the content associated with the given field name. | |
479 If this is a one-example dataset, then a row, i.e., numpy array (of one less dimension | |
480 than the dataset itself) is returned. | |
481 """ | |
482 if len(self.data)==1: | |
483 return self.data[0,self.fields[fieldname]] | |
484 return self.data[:,self.fields[fieldname]] | |
485 | |
486 def __call__(self,*fieldnames): | |
487 """Return a sub-dataset containing only the given fieldnames as fields.""" | |
488 return ArrayDataSet(self.data,fields=LookupList(fieldnames,[self.fields[fieldname] for fieldname in fieldnames])) | |
489 | 496 |
490 def fieldNames(self): | 497 def fieldNames(self): |
491 """Return the list of field names that are supported by getattr and hasField.""" | 498 """Return the list of field names that are supported by getattr and hasField.""" |
492 return self.fields.keys() | 499 return self.fields.keys() |
493 | 500 |
558 else: | 565 else: |
559 step = j-i | 566 step = j-i |
560 i=j | 567 i=j |
561 return slice(start,stop,step) | 568 return slice(start,stop,step) |
562 | 569 |
563 class ApplyFunctionDataSet(DataSet): | 570 class ApplyFunctionDataSet(FiniteWidthDataSet): |
564 """ | 571 """ |
565 A dataset that contains as fields the results of applying | 572 A dataset that contains as fields the results of applying |
566 a given function (example-wise) to specified input_fields of a source | 573 a given function (example-wise) to specified input_fields of a source |
567 dataset. The function should return a sequence whose elements will be stored in | 574 dataset. The function should return a sequence whose elements will be stored in |
568 fields whose names are given in the output_fields list. If copy_inputs | 575 fields whose names are given in the output_fields list. If copy_inputs |
601 elif cache: | 608 elif cache: |
602 # maybe a fixed-size array kind of structure would be more efficient than a list | 609 # maybe a fixed-size array kind of structure would be more efficient than a list |
603 # in the case where src is FiniteDataSet. -YB | 610 # in the case where src is FiniteDataSet. -YB |
604 self.cached_examples = [] | 611 self.cached_examples = [] |
605 | 612 |
613 def fieldNames(self): | |
614 if self.copy_inputs: | |
615 return self.output_fields + self.src.fieldNames() | |
616 return self.output_fields | |
617 | |
606 def minibatches(self, | 618 def minibatches(self, |
607 fieldnames = DataSet.minibatches_fieldnames, | 619 fieldnames = DataSet.minibatches_fieldnames, |
608 minibatch_size = DataSet.minibatches_minibatch_size, | 620 minibatch_size = DataSet.minibatches_minibatch_size, |
609 n_batches = DataSet.minibatches_n_batches): | 621 n_batches = DataSet.minibatches_n_batches): |
610 | 622 |