pylearn: comparison dataset.py @ 22:b6b36f65664f
Created virtual sub-classes of DataSet: {Finite{Length,Width},Sliceable}DataSet,
removed the .field ability from LookupList (because of setattr problems), removed
fieldNames() from DataSet (it remains in FiniteWidthDataSet, where it makes sense),
and added hasFields() instead. Fixed problems in asarray, and tested the
previous functionality in _test_dataset.py, but not yet the new functionality.
author | bengioy@esprit.iro.umontreal.ca |
---|---|
date | Mon, 07 Apr 2008 20:44:37 -0400 |
parents | 266c68cb6136 |
children | 526e192b0699 |
21:fdf0abc490f7 | 22:b6b36f65664f |
---|---|
8 """A virtual base class for datasets. | 8 """A virtual base class for datasets. |
9 | 9 |
10 A DataSet is a generator of iterators; these iterators can run through the | 10 A DataSet is a generator of iterators; these iterators can run through the |
11 examples in a variety of ways. A DataSet need not necessarily have a finite | 11 examples in a variety of ways. A DataSet need not necessarily have a finite |
12 or known length, so this class can be used to interface to a 'stream' which | 12 or known length, so this class can be used to interface to a 'stream' which |
13 feeds on-line learning. | 13 feeds on-line learning. |
14 | 14 |
15 To iterate over examples, there are several possibilities: | 15 To iterate over examples, there are several possibilities: |
16 - for example in dataset.zip([field1, field2,field3, ...]) | 16 - for example in dataset.zip([field1, field2,field3, ...]) |
17 - for val1,val2,val3 in dataset.zip([field1, field2,field3]) | 17 - for val1,val2,val3 in dataset.zip([field1, field2,field3]) |
18 - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N) | 18 - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N) |
19 - for example in dataset | 19 - for example in dataset |
20 Each of these is documented below. | 20 Each of these is documented below. |
21 | 21 |
22 Note: For a dataset of fixed and known length, which can implement item | |
23 random-access efficiently (e.g. indexing and slicing), and which can profit | |
24 from the FiniteDataSetIterator, consider using base class FiniteDataSet. | |
25 | |
26 Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. | 22 Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. |
27 | 23 |
28 Note: The content of a field can be of any type. | 24 Note: The content of a field can be of any type. |
29 | 25 |
26 Note: A dataset can recognize a potentially infinite number of field names (i.e. the field | |
27 values can be computed on-demand, when particular field names are used in one of the | |
28 iterators). | |
29 | |
30 Datasets of finite length should be sub-classes of FiniteLengthDataSet. | |
31 | |
32 Datasets whose elements can be indexed, and from which sub-datasets of |
33 consecutive examples (i.e. slices) can be extracted, should be sub-classes |
34 of SliceableDataSet. |
35 | |
36 Datasets with a finite number of fields should be sub-classes of | |
37 FiniteWidthDataSet. | |
30 """ | 38 """ |
31 | 39 |
32 def __init__(self): | 40 def __init__(self): |
33 pass | 41 pass |
34 | 42 |
43 class Iter(LookupList): | |
44 def __init__(self, ll): | |
45 LookupList.__init__(self, ll.keys(), ll.values()) | |
46 self.ll = ll | |
47 def __iter__(self): #makes for loop work | |
48 return self | |
49 def next(self): | |
50 self.ll.next() | |
51 self._values = [v[0] for v in self.ll._values] | |
52 return self | |
53 | |
35 def __iter__(self): | 54 def __iter__(self): |
36 """Supports the syntax "for i in dataset: ..." | 55 """Supports the syntax "for i in dataset: ..." |
37 | 56 |
38 Using this syntax, "i" will be an Example instance (or equivalent) with | 57 Using this syntax, "i" will be an Example instance (or equivalent) with |
39 all the fields of DataSet self. Every field of "i" will give access to | 58 all the fields of DataSet self. Every field of "i" will give access to |
40 a field of a single example. Fields should be accessible via | 59 a field of a single example. Fields should be accessible via |
41 i["fielname"] or i[3] (in the fieldNames() order), but the derived class is free | 60 i["fielname"] or i[3] (in the order defined by the elements of the |
61 Example returned by this iterator), but the derived class is free | |
42 to accept any type of identifier, and add extra functionality to the iterator. | 62 to accept any type of identifier, and add extra functionality to the iterator. |
43 """ | 63 """ |
44 return self.zip(*self.fieldNames()) | 64 return DataSet.Iter(self.minibatches(None, minibatch_size = 1)) |
45 | 65 |
46 def zip(self, *fieldnames): | 66 def zip(self, *fieldnames): |
47 """ | 67 """ |
48 Supports two forms of syntax: | 68 Supports two forms of syntax: |
49 | 69 |
59 f1, f2, and f3 fields of a single example on each loop iteration. | 79 f1, f2, and f3 fields of a single example on each loop iteration. |
60 | 80 |
61 The derived class may accept fieldname arguments of any type. | 81 The derived class may accept fieldname arguments of any type. |
62 | 82 |
63 """ | 83 """ |
64 class Iter(LookupList): | 84 return DataSet.Iter(self.minibatches(fieldnames, minibatch_size = 1)) |
65 def __init__(self, ll): | |
66 LookupList.__init__(self, ll.keys(), ll.values()) | |
67 self.ll = ll | |
68 def __iter__(self): #makes for loop work | |
69 return self | |
70 def next(self): | |
71 self.ll.next() | |
72 self._values = [v[0] for v in self.ll._values] | |
73 return self | |
74 return Iter(self.minibatches(fieldnames, minibatch_size = 1)) | |
75 | 85 |
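The two zip forms above, plus plain `for example in dataset`, can be exercised against any concrete subclass. A minimal sketch, assuming the ArrayDataSet defined further down in this file; the field names 'x' and 'y', the data values, and the per-name call style `zip('x', 'y')` (which matches the `*fieldnames` signature, whereas the docstring shows a list) are illustrative, not part of the diff:

```python
import numpy

# Hypothetical dataset: 3 examples; field 'x' spans columns 0-1, 'y' column 2.
data = numpy.array([[0., 1., 10.],
                    [2., 3., 11.],
                    [4., 5., 12.]])
ds = ArrayDataSet(data, fields={'x': slice(0, 2, 1), 'y': slice(2, 3, 1)})

# Example-wise iteration over named fields:
for x, y in ds.zip('x', 'y'):
    print x, y

# Whole-example iteration; every field is accessible on each example:
for example in ds:
    print example['x'], example['y']
```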
76 minibatches_fieldnames = None | 86 minibatches_fieldnames = None |
77 minibatches_minibatch_size = 1 | 87 minibatches_minibatch_size = 1 |
78 minibatches_n_batches = None | 88 minibatches_n_batches = None |
79 def minibatches(self, | 89 def minibatches(self, |
80 fieldnames = minibatches_fieldnames, | 90 fieldnames = minibatches_fieldnames, |
81 minibatch_size = minibatches_minibatch_size, | 91 minibatch_size = minibatches_minibatch_size, |
82 n_batches = minibatches_n_batches): | 92 n_batches = minibatches_n_batches): |
83 """ | 93 """ |
84 Supports two forms of syntax: | 94 Supports three forms of syntax: |
95 | |
96 for i in dataset.minibatches(None,**kwargs): ... | |
85 | 97 |
86 for i in dataset.minibatches([f1, f2, f3],**kwargs): ... | 98 for i in dataset.minibatches([f1, f2, f3],**kwargs): ... |
87 | 99 |
88 for i1, i2, i3 in dataset.minibatches([f1, f2, f3],**kwargs): ... | 100 for i1, i2, i3 in dataset.minibatches([f1, f2, f3],**kwargs): ... |
89 | 101 |
90 Using the first syntax, "i" will be an indexable object, such as a list, | 102 Using the first two syntaxes, "i" will be an indexable object, such as a list, |
91 tuple, or Example instance, such that on every iteration, i[0] is a | 103 tuple, or Example instance. In both cases, i[k] is a list-like container |
104 of a batch of current examples. In the second case, i[0] is a |
92 list-like container of the f1 field of a batch of current examples, i[1] is | 105 list-like container of the f1 field of a batch of current examples, i[1] is |
93 a list-like container of the f2 field, etc. | 106 a list-like container of the f2 field, etc. |
94 | 107 |
95 Using the second syntax, i1, i2, i3 will be list-like containers of the | 108 Using the first syntax, all the fields will be returned in "i". |
109 Beware that some datasets may not support this syntax, if the number | |
110 of fields is infinite (i.e. field values may be computed "on demand"). | |
111 | |
112 Using the third syntax, i1, i2, i3 will be list-like containers of the | |
96 f1, f2, and f3 fields of a batch of examples on each loop iteration. | 113 f1, f2, and f3 fields of a batch of examples on each loop iteration. |
97 | 114 |
98 PARAMETERS | 115 PARAMETERS |
99 - fieldnames (list of any type, default None): | 116 - fieldnames (list of any type, default None): |
100 The loop variables i1, i2, i3 (in the example above) should contain the | 117 The loop variables i1, i2, i3 (in the example above) should contain the |
113 Note: A list-like container is something like a tuple, list, numpy.ndarray or | 130 Note: A list-like container is something like a tuple, list, numpy.ndarray or |
114 any other object that supports integer indexing and slicing. | 131 any other object that supports integer indexing and slicing. |
115 | 132 |
116 """ | 133 """ |
117 raise AbstractFunction() | 134 raise AbstractFunction() |
118 | 135 |
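A sketch of the minibatch forms documented above, reusing the hypothetical `ds` from the earlier sketch; the `minibatch_size` and `n_batches` values are illustrative:

```python
# Third syntax: one loop variable per requested field, each holding a batch.
for x_batch, y_batch in ds.minibatches(['x', 'y'], minibatch_size=2, n_batches=2):
    print x_batch.shape, y_batch.shape   # e.g. (2, 2) and (2, 1)

# Second syntax: a single indexable minibatch object per iteration.
for batch in ds.minibatches(['x', 'y'], minibatch_size=2, n_batches=2):
    print batch[0], batch[1]             # the 'x' batch, then the 'y' batch
```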
119 def fieldNames(self): | 136 def hasFields(self, *fieldnames): |
120 #Yoshua- | 137 """ |
121 # This list may not be finite; what would make sense in the use you have | 138 Return true if the given field name (or field names, if multiple arguments are |
122 # in mind? | 139 given) is recognized by the DataSet (i.e. can be used as a field name in one |
123 # -JB | 140 of the iterators). |
124 #James- | 141 """ |
125 # You are right. I had put this to be able to iterate over the fields | 142 raise AbstractFunction() |
126 # but maybe an iterator mechanism (over fields rather than examples) | 143 |
127 # would be more appropriate. Fieldnames are needed in general | |
128 # by the iterators over examples or minibatches, to construct | |
129 # examples or minibatches with the corresponding names as attributes. | |
130 # -YB | |
131 """ | |
132 Return an iterator (an object with an __iter__ method) that | |
133 iterates over the names of the fields. As a special cases, | |
134 a list or a tuple of field names can be returned. | |
135 """" | |
136 # Note that some datasets | |
137 # may have virtual fields and support a virtually infinite number | |
138 # of possible field names. In that case, fieldNames() should | |
139 # either raise an error or iterate over a particular set of | |
140 # names as appropriate. Another option would be to iterate | |
141 # over the sub-datasets comprising a single field at a time. | |
142 # I am not sure yet what is most appropriate. | |
143 # -YB | |
144 """ | |
145 raise AbstractFunction() | |
146 | |
147 def rename(self, *new_field_specifications): | 144 def rename(self, *new_field_specifications): |
148 #Yoshua- | 145 #Yoshua- |
149 # Do you mean for this to be a virtual method? | 146 # Do you mean for this to be a virtual method? |
150 # Wouldn't this functionality be easier to provide via a | 147 # Wouldn't this functionality be easier to provide via a |
151 # RenamingDataSet, such as the one I've written below? | 148 # RenamingDataSet, such as the one I've written below? |
163 of a matrix-like field). | 160 of a matrix-like field). |
164 """ | 161 """ |
165 raise AbstractFunction() | 162 raise AbstractFunction() |
166 | 163 |
167 | 164 |
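For renaming, a minimal sketch against the RenamingDataSet wrapper that appears further down; its constructor signature (a source dataset plus a {new_name: source_name} dict stored as rename_dct) is inferred from its minibatches method and is an assumption:

```python
# Hypothetical: expose source fields 'x' and 'y' under new names.
renamed = RenamingDataSet(ds, {'input': 'x', 'target': 'y'})
for inp, tgt in renamed.zip('input', 'target'):
    print inp, tgt   # same values as ds.zip('x', 'y')
```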
168 def apply_function(self, function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): | 165 def applyFunction(self, function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): |
169 """ | 166 """ |
170 Return a dataset that contains as fields the results of applying | 167 Return a dataset that contains as fields the results of applying |
171 the given function (example-wise) to the specified input_fields. The | 168 the given function (example-wise) to the specified input_fields. The |
172 function should return a sequence whose elements will be stored in | 169 function should return a sequence whose elements will be stored in |
173 fields whose names are given in the output_fields list. If copy_inputs | 170 fields whose names are given in the output_fields list. If copy_inputs |
200 n_batches = DataSet.minibatches_n_batches): | 197 n_batches = DataSet.minibatches_n_batches): |
201 dct = self.rename_dct | 198 dct = self.rename_dct |
202 new_fieldnames = [dct.get(f, f) for f in fieldnames] | 199 new_fieldnames = [dct.get(f, f) for f in fieldnames] |
203 return self.src.minibatches(new_fieldnames, minibatch_size, n_batches) | 200 return self.src.minibatches(new_fieldnames, minibatch_size, n_batches) |
204 | 201 |
205 def fieldNames(self): | 202 class FiniteLengthDataSet(DataSet): |
206 return [dct.get(f, f) for f in self.src.fieldNames()] | 203 """ |
207 | 204 Virtual interface for datasets that have a finite length (number of examples), |
208 | 205 and thus recognize a len(dataset) call. |
209 class FiniteDataSet(DataSet): | 206 """ |
210 """ | 207 def __init__(self): |
211 Virtual interface, a subclass of DataSet for datasets which have a finite, known length. | 208 DataSet.__init__(self) |
212 Examples are indexed by an integer between 0 and self.length()-1, | 209 |
213 and a subdataset can be obtained by slicing. This may not be appropriate in general | 210 def __len__(self): |
214 but only for datasets which can be thought of like ones that access rows AND fields | 211 """len(dataset) returns the number of examples in the dataset.""" |
215 in an efficient random access way. Users are encouraged to expect only the generic dataset | 212 raise AbstractFunction() |
216 interface in general. A FiniteDataSet is mainly useful when one has to obtain | 213 |
217 a subset of examples (e.g. for splitting a dataset into training and test sets). | 214 |
218 """ | 215 class SliceableDataSet(DataSet): |
219 | 216 """ |
220 class FiniteDataSetIterator(object): | 217 Virtual interface, a subclass of DataSet for datasets which are sliceable |
221 """ | 218 and whose individual elements can be accessed, generally respecting the |
222 If the fieldnames list is empty, it means that we want to see ALL the fields. | 219 python semantics for [spec], where spec is either a non-negative integer |
223 """ | 220 (for selecting one example), or a python slice (for selecting a sub-dataset |
224 def __init__(self,dataset,minibatch_size=1,fieldnames=[]): | 221 comprising the specified examples). This is useful for obtaining |
225 self.dataset=dataset | 222 sub-datasets, e.g. for splitting a dataset into training and test sets. |
226 self.minibatch_size=minibatch_size | 223 """ |
227 assert minibatch_size>=1 and minibatch_size<=len(dataset) | 224 def __init__(self): |
228 self.current = -self.minibatch_size | 225 DataSet.__init__(self) |
229 self.fieldnames = fieldnames | |
230 if len(dataset) % minibatch_size: | |
231 raise NotImplementedError() | |
232 | |
233 def __iter__(self): | |
234 return self | |
235 | 226 |
236 def next(self): | |
237 self.current+=self.minibatch_size | |
238 if self.current>=len(self.dataset): | |
239 self.current=-self.minibatch_size | |
240 raise StopIteration | |
241 if self.minibatch_size==1: | |
242 complete_example=self.dataset[self.current] | |
243 else: | |
244 complete_example=self.dataset[self.current:self.current+self.minibatch_size] | |
245 if self.fieldnames: | |
246 return Example(self.fieldnames,list(complete_example)) | |
247 else: | |
248 return complete_example | |
249 | |
250 def __init__(self): | |
251 pass | |
252 | |
253 def minibatches(self, | 227 def minibatches(self, |
254 fieldnames = DataSet.minibatches_fieldnames, | 228 fieldnames = DataSet.minibatches_fieldnames, |
255 minibatch_size = DataSet.minibatches_minibatch_size, | 229 minibatch_size = DataSet.minibatches_minibatch_size, |
256 n_batches = DataSet.minibatches_n_batches): | 230 n_batches = DataSet.minibatches_n_batches): |
257 """ | 231 """ |
258 If the fieldnames list is empty, it means that we want to see ALL the fields. | |
259 | |
260 If the n_batches is empty, we want to see all the examples possible | 232 If n_batches is None, we want to see all the examples possible |
261 for the given minibatch_size. | 233 for the given minibatch_size (possibly missing a few at the end of the dataset). |
262 """ | 234 """ |
263 # substitute the defaults: | 235 # substitute the defaults: |
264 if fieldnames is None: fieldnames = self.fieldNames() | |
265 if n_batches is None: n_batches = len(self) / minibatch_size | 236 if n_batches is None: n_batches = len(self) / minibatch_size |
266 return DataSet.Iterator(self, fieldnames, minibatch_size, n_batches) | 237 return DataSet.Iterator(self, fieldnames, minibatch_size, n_batches) |
267 | 238 |
268 def __getattr__(self,fieldname): | |
269 """Return an that can iterate over the values of the field in this dataset.""" | |
270 return self(fieldname) | |
271 | |
272 def __call__(self,*fieldnames): | |
273 """Return a sub-dataset containing only the given fieldnames as fields. | |
274 | |
275 The return value's default iterator will iterate only over the given | |
276 fields. | |
277 """ | |
278 raise AbstractFunction() | |
279 | |
280 def __len__(self): | |
281 """len(dataset) returns the number of examples in the dataset.""" | |
282 raise AbstractFunction() | |
283 | |
284 def __getitem__(self,i): | 239 def __getitem__(self,i): |
285 """dataset[i] returns the (i+1)-th example of the dataset.""" | 240 """dataset[i] returns the (i+1)-th example of the dataset.""" |
286 raise AbstractFunction() | 241 raise AbstractFunction() |
287 | 242 |
288 def __getslice__(self,*slice_args): | 243 def __getslice__(self,*slice_args): |
289 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" | 244 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" |
290 raise AbstractFunction() | 245 raise AbstractFunction() |
246 | |
247 | |
248 class FiniteWidthDataSet(DataSet): | |
249 """ | |
250 Virtual interface for datasets that have a finite width (number of fields), | |
251 and thus return a list of fieldNames. | |
252 """ | |
253 def __init__(self): | |
254 DataSet.__init__(self) | |
255 | |
256 def hasFields(self, *fieldnames): |
257 has_fields=True |
258 for fieldname in fieldnames: |
259 if fieldname not in self.fieldNames(): |
260 has_fields=False |
261 return has_fields |
262 | |
263 def fieldNames(self): | |
264 """Return the list of field names that are supported by the iterators, | |
265 and for which hasFields(fieldname) would return True.""" | |
266 raise AbstractFunction() | |
267 | |
291 | 268 |
292 # we may want ArrayDataSet defined in another python file | 269 # we may want ArrayDataSet defined in another python file |
293 | 270 |
294 import numpy | 271 import numpy |
295 | 272 |
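To make the fields-as-column-slices convention concrete before the class definition below, a small sketch; the array contents and field names are hypothetical:

```python
# Rows are examples; each field names a column range of the 2-D array.
data = numpy.random.random_sample((4, 5))
ds2 = ArrayDataSet(data, fields={'input': slice(0, 4, 1),
                                 'target': slice(4, 5, 1)})
print len(ds2)                          # 4 examples
print ds2.fieldNames()                  # ['input', 'target'] (dict order)
print ds2.hasFields('input', 'target')  # True
```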
324 # many complicated things remain to be done: | 301 # many complicated things remain to be done: |
325 # - find common dtype | 302 # - find common dtype |
326 # - decide what to do with extra dimensions if not the same in all fields | 303 # - decide what to do with extra dimensions if not the same in all fields |
327 # - try to see if we can avoid the copy? | 304 # - try to see if we can avoid the copy? |
328 | 305 |
329 class ArrayDataSet(FiniteDataSet): | 306 class ArrayDataSet(FiniteLengthDataSet,FiniteWidthDataSet,SliceableDataSet): |
330 """ | 307 """ |
331 An ArrayDataSet behaves like a numpy array but adds the notion of named fields | 308 An ArrayDataSet behaves like a numpy array but adds the notion of named fields |
332 from DataSet (and the ability to view the values of multiple fields as an 'Example'). | 309 from DataSet (and the ability to view the values of multiple fields as an 'Example'). |
333 It is a fixed-length and fixed-width dataset | 310 It is a fixed-length and fixed-width dataset |
334 in which each element is a fixed dimension numpy array or a number, hence the whole | 311 in which each element is a fixed dimension numpy array or a number, hence the whole |
340 """ | 317 """ |
341 | 318 |
342 class Iterator(LookupList): | 319 class Iterator(LookupList): |
343 """An iterator over a finite dataset that implements wrap-around""" | 320 """An iterator over a finite dataset that implements wrap-around""" |
344 def __init__(self, dataset, fieldnames, minibatch_size, next_max): | 321 def __init__(self, dataset, fieldnames, minibatch_size, next_max): |
345 LookupList.__init__(self, fieldnames, [0] * len(fieldnames)) | 322 if fieldnames is None: fieldnames = dataset.fieldNames() |
323 LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) | |
346 self.dataset=dataset | 324 self.dataset=dataset |
347 self.minibatch_size=minibatch_size | 325 self.minibatch_size=minibatch_size |
348 self.next_count = 0 | 326 self.next_count = 0 |
349 self.next_max = next_max | 327 self.next_max = next_max |
350 self.current = -self.minibatch_size | 328 self.current = -self.minibatch_size |
390 dataview = data[self.current:] | 368 dataview = data[self.current:] |
391 upper -= rows | 369 upper -= rows |
392 assert upper > 0 | 370 assert upper > 0 |
393 dataview = self.matcat(dataview, data[:upper]) | 371 dataview = self.matcat(dataview, data[:upper]) |
394 | 372 |
395 | |
396 self._values = [dataview[:, self.dataset.fields[f]]\ | 373 self._values = [dataview[:, self.dataset.fields[f]]\ |
397 for f in self._names] | 374 for f in self._names] |
398 | |
399 return self | 375 return self |
400 | 376 |
401 | 377 |
402 def __init__(self, data, fields=None): | 378 def __init__(self, data, fields=None): |
403 """ | 379 """ |
427 def minibatches(self, | 403 def minibatches(self, |
428 fieldnames = DataSet.minibatches_fieldnames, | 404 fieldnames = DataSet.minibatches_fieldnames, |
429 minibatch_size = DataSet.minibatches_minibatch_size, | 405 minibatch_size = DataSet.minibatches_minibatch_size, |
430 n_batches = DataSet.minibatches_n_batches): | 406 n_batches = DataSet.minibatches_n_batches): |
431 """ | 407 """ |
432 If the fieldnames list is empty, it means that we want to see ALL the fields. | 408 If the fieldnames list is None, it means that we want to see ALL the fields. |
433 | 409 |
434 If the n_batches is empty, we want to see all the examples possible | 410 If the n_batches is None, we want to see all the examples possible |
435 for the give minibatch_size. | 411 for the given minibatch_size (possibly missing some near the end). |
436 """ | 412 """ |
437 # substitute the defaults: | 413 # substitute the defaults: |
438 if fieldnames is None: fieldnames = self.fieldNames() | |
439 if n_batches is None: n_batches = len(self) / minibatch_size | 414 if n_batches is None: n_batches = len(self) / minibatch_size |
440 return ArrayDataSet.Iterator(self, fieldnames, minibatch_size, n_batches) | 415 return ArrayDataSet.Iterator(self, fieldnames, minibatch_size, n_batches) |
441 | 416 |
442 def __getattr__(self,fieldname): | 417 def __getattr__(self,fieldname): |
443 """ | 418 """ |
460 for fieldname,fieldslice in self.fields.items(): | 435 for fieldname,fieldslice in self.fields.items(): |
461 new_fields[fieldname]=slice(fieldslice.start-min_col,fieldslice.stop-min_col,fieldslice.step) | 436 new_fields[fieldname]=slice(fieldslice.start-min_col,fieldslice.stop-min_col,fieldslice.step) |
462 return ArrayDataSet(self.data[:,min_col:max_col],fields=new_fields) | 437 return ArrayDataSet(self.data[:,min_col:max_col],fields=new_fields) |
463 | 438 |
464 def fieldNames(self): | 439 def fieldNames(self): |
465 """Return the list of field names that are supported by getattr and getFields.""" | 440 """Return the list of field names that are supported by getattr and hasField.""" |
466 return self.fields.keys() | 441 return self.fields.keys() |
467 | 442 |
468 def __len__(self): | 443 def __len__(self): |
469 """len(dataset) returns the number of examples in the dataset.""" | 444 """len(dataset) returns the number of examples in the dataset.""" |
470 return len(self.data) | 445 return len(self.data) |
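Indexing and slicing follow the SliceableDataSet contract above: dataset[i] is one example, dataset[i:j] a sub-dataset. A sketch reusing the hypothetical ds2; the train/test split mirrors the use case named in the SliceableDataSet docstring:

```python
example = ds2[0]              # the first example
train = ds2[0:3]              # sub-dataset of examples 0, 1, 2
test = ds2[3:4]               # the remaining example, e.g. a held-out split
print len(train), len(test)   # 3 1
```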
500 """ | 475 """ |
501 if not self.fields: | 476 if not self.fields: |
502 return self.data | 477 return self.data |
503 # else, select subsets of columns mapped by the fields | 478 # else, select subsets of columns mapped by the fields |
504 columns_used = numpy.zeros((self.data.shape[1]),dtype=bool) | 479 columns_used = numpy.zeros((self.data.shape[1]),dtype=bool) |
480 overlapping_fields = False | |
481 n_columns = 0 | |
505 for field_slice in self.fields.values(): | 482 for field_slice in self.fields.values(): |
506 for c in xrange(field_slice.start,field_slice.stop,field_slice.step): | 483 for c in xrange(field_slice.start,field_slice.stop,field_slice.step): |
484 n_columns += 1 | |
485 if columns_used[c]: overlapping_fields=True | |
507 columns_used[c]=True | 486 columns_used[c]=True |
508 # try to figure out if we can map all the slices into one slice: | 487 # try to figure out if we can map all the slices into one slice: |
509 mappable_to_one_slice = True | 488 mappable_to_one_slice = not overlapping_fields |
510 start=0 | 489 if not overlapping_fields: |
511 while start<len(columns_used) and not columns_used[start]: | 490 start=0 |
512 start+=1 | 491 while start<len(columns_used) and not columns_used[start]: |
513 stop=len(columns_used) | 492 start+=1 |
514 while stop>0 and not columns_used[stop-1]: | 493 stop=len(columns_used) |
515 stop-=1 | 494 while stop>0 and not columns_used[stop-1]: |
516 step=0 | 495 stop-=1 |
517 i=start | 496 step=0 |
518 while i<stop: | 497 i=start |
519 j=i+1 | 498 while i<stop: |
520 while j<stop and not columns_used[j]: | 499 j=i+1 |
521 j+=1 | 500 while j<stop and not columns_used[j]: |
522 if step: | 501 j+=1 |
523 if step!=j-i: | 502 if step: |
524 mappable_to_one_slice = False | 503 if step!=j-i: |
525 break | 504 mappable_to_one_slice = False |
526 else: | 505 break |
527 step = j-i | 506 else: |
528 i=j | 507 step = j-i |
508 i=j | |
529 if mappable_to_one_slice: | 509 if mappable_to_one_slice: |
530 return self.data[:,slice(start,stop,step)] | 510 return self.data[:,slice(start,stop,step)] |
531 # else make contiguous copy | 511 # else make contiguous copy (copying the overlapping columns) |
532 n_columns = sum(columns_used) | 512 result = numpy.zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype) |
533 result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype) | |
534 print result.shape | |
535 c=0 | 513 c=0 |
536 for field_slice in self.fields.values(): | 514 for field_slice in self.fields.values(): |
537 slice_width=field_slice.stop-field_slice.start/field_slice.step | 515 slice_width=(field_slice.stop-field_slice.start)/field_slice.step |
538 # copy the field here | 516 # copy the field here |
539 result[:,slice(c,slice_width)]=self.data[:,field_slice] | 517 result[:,slice(c,c+slice_width)]=self.data[:,field_slice] |
540 c+=slice_width | 518 c+=slice_width |
541 return result | 519 return result |
542 | 520 |
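The slice-merging logic above can be seen at both extremes. A sketch assuming the method is exposed as asarray (the name comes from the commit message) and reusing the hypothetical data array:

```python
# Adjacent, non-overlapping fields merge into a single slice: no copy.
view_ds = ArrayDataSet(data, fields={'a': slice(0, 2, 1), 'b': slice(2, 4, 1)})
arr = view_ds.asarray()    # a view, data[:, 0:4:1]

# Columns 0, 1, 3 form no single arithmetic slice: the copy branch runs.
gap_ds = ArrayDataSet(data, fields={'a': slice(0, 2, 1), 'b': slice(3, 4, 1)})
arr2 = gap_ds.asarray()    # a freshly allocated contiguous copy
```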
543 class ApplyFunctionDataset(DataSet): | 521 class ApplyFunctionDataSet(DataSet): |
544 """ | 522 """ |
545 A dataset that contains as fields the results of applying | 523 A dataset that contains as fields the results of applying |
546 a given function (example-wise) to specified input_fields of a source | 524 a given function (example-wise) to specified input_fields of a source |
547 dataset. The function should return a sequence whose elements will be stored in | 525 dataset. The function should return a sequence whose elements will be stored in |
548 fields whose names are given in the output_fields list. If copy_inputs | 526 fields whose names are given in the output_fields list. If copy_inputs |
581 | 559 |
582 def minibatches(self, | 560 def minibatches(self, |
583 fieldnames = DataSet.minibatches_fieldnames, | 561 fieldnames = DataSet.minibatches_fieldnames, |
584 minibatch_size = DataSet.minibatches_minibatch_size, | 562 minibatch_size = DataSet.minibatches_minibatch_size, |
585 n_batches = DataSet.minibatches_n_batches): | 563 n_batches = DataSet.minibatches_n_batches): |
586 | 564 |
587 class Iterator(LookupList): | 565 class Iterator(LookupList): |
588 | 566 |
589 def __init__(self,dataset): | 567 def __init__(self,dataset): |
590 LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) | 568 if fieldnames is None: |
569 LookupList.__init__(self, [],[]) | |
570 else: | |
571 LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) | |
591 self.dataset=dataset | 572 self.dataset=dataset |
592 if dataset.copy_inputs: | 573 self.src_iterator=self.dataset.src.minibatches(list(set.union(set(fieldnames),set(self.dataset.input_fields))), |
593 src_fields=dataset.fieldNames() | 574 minibatch_size,n_batches) |
594 else: | |
595 src_fields=dataset.input_fields | |
596 self.src_iterator=self.src.minibatches(src_fields,minibatch_size,n_batches) | |
597 | 575 |
598 def __iter__(self): | 576 def __iter__(self): |
599 return self | 577 return self |
600 | 578 |
601 def next(self): | 579 def next(self): |
602 src_examples = self.src_iterator.next() | 580 src_examples = self.src_iterator.next() |
603 if self.dataset.copy_inputs: | 581 # the function always consumes exactly the input fields: |
604 function_inputs = src_examples | 582 function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields] |
605 else: | |
606 function_inputs = | |
607 [src_examples[field_name] for field_name in self.dataset.input_fields]) | 585 outputs = Example(self.dataset.output_fields,self.dataset.function(*function_inputs)) |
608 return self.dataset.function(*function_inputs) | 586 if self.dataset.copy_inputs: |
587 return src_examples + outputs | |
588 else: | |
589 return outputs | |
609 | 590 |
610 for fieldname in fieldnames: | 591 for fieldname in fieldnames: |
611 assert fieldname in self.input_fields | 592 assert fieldname in self.output_fields or self.src.hasFields(fieldname) |
612 return Iterator(self) | 593 return Iterator(self) |
613 | 594 |
614 | 595 |
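Finally, a hypothetical end-to-end use of applyFunction / ApplyFunctionDataSet as specified above; the normalize function, the field names, and the assumption that a concrete dataset's applyFunction constructs an ApplyFunctionDataSet are all illustrative:

```python
def normalize(x):
    # One input field in, a one-element sequence out, per the contract.
    return ((x - x.mean()) / (x.std() + 1e-8),)

nds = ds2.applyFunction(normalize, input_fields=['input'],
                        output_fields=['normalized_input'])
for inp, ninp in nds.zip('input', 'normalized_input'):
    print inp, ninp
```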