Mercurial > pylearn
comparison dataset.py @ 26:672fe4b23032
Fixed dataset errors so that _test_dataset.py works again.
author | bengioy@grenat.iro.umontreal.ca |
---|---|
date | Fri, 11 Apr 2008 11:14:54 -0400 |
parents | 526e192b0699 |
children | 541a273bc89f |
comparison
equal
deleted
inserted
replaced
23:526e192b0699 | 26:672fe4b23032 |
---|---|
1 | 1 |
2 from lookup_list import LookupList | 2 from lookup_list import LookupList |
3 Example = LookupList | 3 Example = LookupList |
4 import copy | |
4 | 5 |
5 class AbstractFunction (Exception): """Derived class must override this function""" | 6 class AbstractFunction (Exception): """Derived class must override this function""" |
6 | 7 |
7 class DataSet(object): | 8 class DataSet(object): |
8 """A virtual base class for datasets. | 9 """A virtual base class for datasets. |
140 any other object that supports integer indexing and slicing. | 141 any other object that supports integer indexing and slicing. |
141 | 142 |
142 """ | 143 """ |
143 raise AbstractFunction() | 144 raise AbstractFunction() |
144 | 145 |
145 def hasFields(*fieldnames): | 146 def hasFields(self,*fieldnames): |
146 """ | 147 """ |
147 Return true if the given field name (or field names, if multiple arguments are | 148 Return true if the given field name (or field names, if multiple arguments are |
148 given) is recognized by the DataSet (i.e. can be used as a field name in one | 149 given) is recognized by the DataSet (i.e. can be used as a field name in one |
149 of the iterators). | 150 of the iterators). |
150 """ | 151 """ |
151 raise AbstractFunction() | 152 raise AbstractFunction() |
152 | 153 |
153 def merge_fields(*specifications): | 154 def merge_fields(self,*specifications): |
154 """ | 155 """ |
155 Return a new dataset that maps old fields (of self) to new fields (of the returned | 156 Return a new dataset that maps old fields (of self) to new fields (of the returned |
156 dataset). The minimal syntax that should be supported is the following: | 157 dataset). The minimal syntax that should be supported is the following: |
157 new_field_specifications = [new_field_spec1, new_field_spec2, ...] | 158 new_field_specifications = [new_field_spec1, new_field_spec2, ...] |
158 new_field_spec = ([old_field1, old_field2, ...], new_field) | 159 new_field_spec = ([old_field1, old_field2, ...], new_field) |
160 support additional indexing schemes within each field (e.g. column slice | 161 support additional indexing schemes within each field (e.g. column slice |
161 of a matrix-like field). | 162 of a matrix-like field). |
162 """ | 163 """ |
163 raise AbstractFunction() | 164 raise AbstractFunction() |
164 | 165 |
165 def merge_field_values(*field_value_pairs) | 166 def merge_field_values(self,*field_value_pairs): |
166 """ | 167 """ |
167 Return the value that corresponds to merging the values of several fields, | 168 Return the value that corresponds to merging the values of several fields, |
168 given as arguments (field_name, field_value) pairs with self.hasField(field_name). | 169 given as arguments (field_name, field_value) pairs with self.hasField(field_name). |
169 This may be used by implementations of merge_fields. | 170 This may be used by implementations of merge_fields. |
170 Raise a ValueError if the operation is not possible. | 171 Raise a ValueError if the operation is not possible. |
171 """ | 172 """ |
172 fieldnames,fieldvalues = zip(*field_value_pairs) | 173 fieldnames,fieldvalues = zip(*field_value_pairs) |
173 raise ValueError("Unable to merge values of these fields:"+repr(fieldnames)) | 174 raise ValueError("Unable to merge values of these fields:"+repr(fieldnames)) |
174 | 175 |
175 def examples2minibatch(examples): | 176 def examples2minibatch(self,examples): |
176 """ | 177 """ |
177 Combine a list of Examples into a minibatch. A minibatch is an Example whose fields | 178 Combine a list of Examples into a minibatch. A minibatch is an Example whose fields |
178 are iterable over the examples of the minibatch. | 179 are iterable over the examples of the minibatch. |
179 """ | 180 """ |
180 raise AbstractFunction() | 181 raise AbstractFunction() |
181 | 182 |
182 def rename(rename_dict): | 183 def rename(self,rename_dict): |
183 """ | 184 """ |
184 Return a new dataset that renames fields, using a dictionary that maps old field | 185 Return a new dataset that renames fields, using a dictionary that maps old field |
185 names to new field names. The only fields visible by the returned dataset are those | 186 names to new field names. The only fields visible by the returned dataset are those |
186 whose names are keys of the rename_dict. | 187 whose names are keys of the rename_dict. |
187 """ | 188 """ |
188 return RenamingDataSet(self,rename_dict) | 189 self_class = self.__class__ |
189 | 190 class SelfRenamingDataSet(RenamingDataSet,self_class): |
190 def applyFunction(function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): | 191 pass |
192 self.__class__ = SelfRenamingDataSet | |
193 # set the rename_dict and src fields | |
194 SelfRenamingDataSet.__init__(self,self,rename_dict) | |
195 return self | |
196 | |
197 def applyFunction(self,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): | |
191 """ | 198 """ |
192 Return a dataset that contains as fields the results of applying | 199 Return a dataset that contains as fields the results of applying |
193 the given function (example-wise) to the specified input_fields. The | 200 the given function (example-wise) to the specified input_fields. The |
194 function should return a sequence whose elements will be stored in | 201 function should return a sequence whose elements will be stored in |
195 fields whose names are given in the output_fields list. If copy_inputs | 202 fields whose names are given in the output_fields list. If copy_inputs |
202 are cached (to avoid recomputation if the same examples are again | 209 are cached (to avoid recomputation if the same examples are again |
203 requested). | 210 requested). |
204 """ | 211 """ |
205 return ApplyFunctionDataSet(function, input_fields, output_fields, copy_inputs, accept_minibatches, cache) | 212 return ApplyFunctionDataSet(function, input_fields, output_fields, copy_inputs, accept_minibatches, cache) |
206 | 213 |
207 class RenamingDataSet(DataSet): | 214 |
215 class FiniteLengthDataSet(DataSet): | |
216 """ | |
217 Virtual interface for datasets that have a finite length (number of examples), | |
218 and thus recognize a len(dataset) call. | |
219 """ | |
220 def __init__(self): | |
221 DataSet.__init__(self) | |
222 | |
223 def __len__(self): | |
224 """len(dataset) returns the number of examples in the dataset.""" | |
225 raise AbstractFunction() | |
226 | |
227 | |
228 class SliceableDataSet(DataSet): | |
229 """ | |
230 Virtual interface, a subclass of DataSet for datasets which are sliceable | |
231 and whose individual elements can be accessed, generally respecting the | |
232 python semantics for [spec], where spec is either a non-negative integer | |
233 (for selecting one example), or a python slice (for selecting a sub-dataset | |
234 comprising the specified examples). This is useful for obtaining | |
235 sub-datasets, e.g. for splitting a dataset into training and test sets. | |
236 """ | |
237 def __init__(self): | |
238 DataSet.__init__(self) | |
239 | |
240 def minibatches(self, | |
241 fieldnames = DataSet.minibatches_fieldnames, | |
242 minibatch_size = DataSet.minibatches_minibatch_size, | |
243 n_batches = DataSet.minibatches_n_batches): | |
244 """ | |
245 If the n_batches is empty, we want to see all the examples possible | |
246 for the given minibatch_size (possibly missing a few at the end of the dataset). | |
247 """ | |
248 # substitute the defaults: | |
249 if n_batches is None: n_batches = len(self) / minibatch_size | |
250 return DataSet.Iterator(self, fieldnames, minibatch_size, n_batches) | |
251 | |
252 def __getitem__(self,i): | |
253 """dataset[i] returns the (i+1)-th example of the dataset.""" | |
254 raise AbstractFunction() | |
255 | |
256 def __getslice__(self,*slice_args): | |
257 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" | |
258 raise AbstractFunction() | |
259 | |
260 | |
261 class FiniteWidthDataSet(DataSet): | |
262 """ | |
263 Virtual interface for datasets that have a finite width (number of fields), | |
264 and thus return a list of fieldNames. | |
265 """ | |
266 def __init__(self): | |
267 DataSet.__init__(self) | |
268 | |
269 def hasFields(self,*fields): | |
270 has_fields=True | |
271 fieldnames = self.fieldNames() | |
272 for name in fields: | |
273 if name not in fieldnames: | |
274 has_fields=False | |
275 return has_fields | |
276 | |
277 def fieldNames(self): | |
278 """Return the list of field names that are supported by the iterators, | |
279 and for which hasFields(fieldname) would return True.""" | |
280 raise AbstractFunction() | |
281 | |
282 | |
283 class RenamingDataSet(FiniteWidthDataSet): | |
208 """A DataSet that wraps another one, and makes it look like the field names | 284 """A DataSet that wraps another one, and makes it look like the field names |
209 are different | 285 are different |
210 | 286 |
211 Renaming is done by a dictionary that maps new names to the old ones used in | 287 Renaming is done by a dictionary that maps new names to the old ones used in |
212 self.src. | 288 self.src. |
214 def __init__(self, src, rename_dct): | 290 def __init__(self, src, rename_dct): |
215 DataSet.__init__(self) | 291 DataSet.__init__(self) |
216 self.src = src | 292 self.src = src |
217 self.rename_dct = copy.copy(rename_dct) | 293 self.rename_dct = copy.copy(rename_dct) |
218 | 294 |
295 def fieldNames(self): | |
296 return self.rename_dct.keys() | |
297 | |
219 def minibatches(self, | 298 def minibatches(self, |
220 fieldnames = DataSet.minibatches_fieldnames, | 299 fieldnames = DataSet.minibatches_fieldnames, |
221 minibatch_size = DataSet.minibatches_minibatch_size, | 300 minibatch_size = DataSet.minibatches_minibatch_size, |
222 n_batches = DataSet.minibatches_n_batches): | 301 n_batches = DataSet.minibatches_n_batches): |
223 dct = self.rename_dct | 302 dct = self.rename_dct |
224 new_fieldnames = [dct.get(f, f) for f in fieldnames] | 303 new_fieldnames = [dct.get(f, f) for f in fieldnames] |
225 return self.src.minibatches(new_fieldnames, minibatches_size, n_batches) | 304 return self.src.minibatches(new_fieldnames, minibatches_size, n_batches) |
226 | |
227 class FiniteLengthDataSet(DataSet): | |
228 """ | |
229 Virtual interface for datasets that have a finite length (number of examples), | |
230 and thus recognize a len(dataset) call. | |
231 """ | |
232 def __init__(self): | |
233 DataSet.__init__(self) | |
234 | |
235 def __len__(self): | |
236 """len(dataset) returns the number of examples in the dataset.""" | |
237 raise AbstractFunction() | |
238 | |
239 | |
240 class SliceableDataSet(DataSet): | |
241 """ | |
242 Virtual interface, a subclass of DataSet for datasets which are sliceable | |
243 and whose individual elements can be accessed, generally respecting the | |
244 python semantics for [spec], where spec is either a non-negative integer | |
245 (for selecting one example), or a python slice (for selecting a sub-dataset | |
246 comprising the specified examples). This is useful for obtaining | |
247 sub-datasets, e.g. for splitting a dataset into training and test sets. | |
248 """ | |
249 def __init__(self): | |
250 DataSet.__init__(self) | |
251 | |
252 def minibatches(self, | |
253 fieldnames = DataSet.minibatches_fieldnames, | |
254 minibatch_size = DataSet.minibatches_minibatch_size, | |
255 n_batches = DataSet.minibatches_n_batches): | |
256 """ | |
257 If the n_batches is empty, we want to see all the examples possible | |
258 for the given minibatch_size (possibly missing a few at the end of the dataset). | |
259 """ | |
260 # substitute the defaults: | |
261 if n_batches is None: n_batches = len(self) / minibatch_size | |
262 return DataSet.Iterator(self, fieldnames, minibatch_size, n_batches) | |
263 | |
264 def __getitem__(self,i): | |
265 """dataset[i] returns the (i+1)-th example of the dataset.""" | |
266 raise AbstractFunction() | |
267 | |
268 def __getslice__(self,*slice_args): | |
269 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" | |
270 raise AbstractFunction() | |
271 | |
272 | |
273 class FiniteWidthDataSet(DataSet): | |
274 """ | |
275 Virtual interface for datasets that have a finite width (number of fields), | |
276 and thus return a list of fieldNames. | |
277 """ | |
278 def __init__(self): | |
279 DataSet.__init__(self) | |
280 | |
281 def hasFields(*fieldnames): | |
282 has_fields=True | |
283 for fieldname in fieldnames: | |
284 if fieldname not in self.fields.keys(): | |
285 has_fields=False | |
286 return has_fields | |
287 | |
288 def fieldNames(self): | |
289 """Return the list of field names that are supported by the iterators, | |
290 and for which hasFields(fieldname) would return True.""" | |
291 raise AbstractFunction() | |
292 | 305 |
293 | 306 |
294 # we may want ArrayDataSet defined in another python file | 307 # we may want ArrayDataSet defined in another python file |
295 | 308 |
296 import numpy | 309 import numpy |
546 # copy the field here | 559 # copy the field here |
547 result[:,slice(c,c+slice_width)]=self.data[:,field_slice] | 560 result[:,slice(c,c+slice_width)]=self.data[:,field_slice] |
548 c+=slice_width | 561 c+=slice_width |
549 return result | 562 return result |
550 | 563 |
551 def rename(*new_field_specifications): | |
552 """ | |
553 Return a new dataset that maps old fields (of self) to new fields (of the returned | |
554 dataset). The minimal syntax that should be supported is the following: | |
555 new_field_specifications = [new_field_spec1, new_field_spec2, ...] | |
556 new_field_spec = ([old_field1, old_field2, ...], new_field) | |
557 In general both old_field and new_field should be strings, but some datasets may also | |
558 support additional indexing schemes within each field (e.g. column slice | |
559 of a matrix-like field). | |
560 """ | |
561 # if all old fields of each spec are | |
562 raise NotImplementedError() | |
563 | |
564 class ApplyFunctionDataSet(DataSet): | 564 class ApplyFunctionDataSet(DataSet): |
565 """ | 565 """ |
566 A dataset that contains as fields the results of applying | 566 A dataset that contains as fields the results of applying |
567 a given function (example-wise) to specified input_fields of a source | 567 a given function (example-wise) to specified input_fields of a source |
568 dataset. The function should return a sequence whose elements will be stored in | 568 dataset. The function should return a sequence whose elements will be stored in |
597 # and apply the function to it, and transpose into a list of examples (field values, actually) | 597 # and apply the function to it, and transpose into a list of examples (field values, actually) |
598 self.cached_examples = zip(*Example(output_fields,function(*inputs))) | 598 self.cached_examples = zip(*Example(output_fields,function(*inputs))) |
599 else: | 599 else: |
600 # compute a list with one tuple per example, with the function outputs | 600 # compute a list with one tuple per example, with the function outputs |
601 self.cached_examples = [ function(input) for input in src.zip(input_fields) ] | 601 self.cached_examples = [ function(input) for input in src.zip(input_fields) ] |
602 else if cache: | 602 elif cache: |
603 # maybe a fixed-size array kind of structure would be more efficient than a list | 603 # maybe a fixed-size array kind of structure would be more efficient than a list |
604 # in the case where src is FiniteDataSet. -YB | 604 # in the case where src is FiniteDataSet. -YB |
605 self.cached_examples = [] | 605 self.cached_examples = [] |
606 | 606 |
607 def minibatches(self, | 607 def minibatches(self, |