Mercurial > pylearn

comparison dataset.py @ 36:438440ba0627
Rewriting dataset.py completely

author   | bengioy@zircon.iro.umontreal.ca
date     | Tue, 22 Apr 2008 18:03:11 -0400
parents  | 46c5c90019c2
children | 73c4212ba5b3
35:2508c373cf29 (old) | 36:438440ba0627 (new) |
1 | 1 |
2 from lookup_list import LookupList | 2 from lookup_list import LookupList |
3 Example = LookupList | 3 Example = LookupList |
4 from misc import * | |
4 import copy | 5 import copy |
5 | 6 |
6 class AbstractFunction (Exception): """Derived class must override this function""" | 7 class AbstractFunction (Exception): """Derived class must override this function""" |
7 | 8 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" |
9 | |
8 class DataSet(object): | 10 class DataSet(object): |
9 """A virtual base class for datasets. | 11 """A virtual base class for datasets. |
10 | 12 |
13 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction | |
14 with learning algorithms (for training and testing them): rows/records are called examples, and | |
15 columns/attributes are called fields. The field value for a particular example can be an arbitrary | |
16 python object, which depends on the particular dataset. | |
17 | |
18 We call a DataSet a 'stream' when its length is unbounded (len(dataset)==float("infinity")). | |
19 | |
11 A DataSet is a generator of iterators; these iterators can run through the | 20 A DataSet is a generator of iterators; these iterators can run through the |
12 examples in a variety of ways. A DataSet need not necessarily have a finite | 21 examples or the fields in a variety of ways. A DataSet need not necessarily have a finite |
13 or known length, so this class can be used to interface to a 'stream' which | 22 or known length, so this class can be used to interface to a 'stream' which |
14 feeds on-line learning. | 23 feeds on-line learning (however, as noted below, some operations are not |
24 feasible or not recommended on streams). | |
15 | 25 |
16 To iterate over examples, there are several possibilities: | 26 To iterate over examples, there are several possibilities: |
17 - for example in dataset.zip([field1, field2,field3, ...]) | 27 * for example in dataset([field1, field2,field3, ...]): |
18 - for val1,val2,val3 in dataset.zip([field1, field2,field3]) | 28 * for val1,val2,val3 in dataset([field1, field2,field3]): |
19 - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N) | 29 * for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): |
20 - for example in dataset | 30 * for example in dataset: |
21 Each of these is documented below. All of these iterators are expected | 31 Each of these is documented below. All of these iterators are expected |
22 to provide, in addition to the usual 'next()' method, a 'next_index()' method | 32 to provide, in addition to the usual 'next()' method, a 'next_index()' method |
23 which returns a non-negative integer pointing to the position of the next | 33 which returns a non-negative integer pointing to the position of the next |
24 example that will be returned by 'next()' (or of the first example in the | 34 example that will be returned by 'next()' (or of the first example in the |
25 next minibatch returned). This is important because these iterators | 35 next minibatch returned). This is important because these iterators |
26 can wrap around the dataset in order to do multiple passes through it, | 36 can wrap around the dataset in order to do multiple passes through it, |
27 in possibly irregular ways if the minibatch size is not a divisor of the | 37 in possibly irregular ways if the minibatch size is not a divisor of the |
28 dataset length. | 38 dataset length. |
29 | 39 |
40 To iterate over fields, one can do | |
41 * for fields in dataset.fields() | |
42 * for fields in dataset(field1,field2,...).fields() to select a subset of fields | |
43 * for fields in dataset.fields(field1,field2,...) to select a subset of fields | |
44 and each of these fields is iterable over the examples: | |
45 * for field_examples in dataset.fields(): | |
46 for example_value in field_examples: | |
47 ... | |
48 but when the dataset is a stream (unbounded length), it is not recommended to do | |
49 such things because the underlying dataset may refuse to access the different fields in | |
50 an unsynchronized way. Hence the fields() method is illegal for streams, by default. | |
51 The result of fields() is a DataSetFields object, which iterates over fields, | |
52 and whose elements are iterable over examples. A DataSetFields object can | |
53 be turned back into a DataSet with its examples() method: | |
54 dataset2 = dataset1.fields().examples() | |
55 and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1). | |
56 | |
30 Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. | 57 Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. |
31 | 58 |
32 Note: The content of a field can be of any type. | 59 Note: The content of a field can be of any type. Field values can also be 'missing' |
33 | 60 (e.g. to handle semi-supervised learning), and in the case of numeric (numpy array) |
34 Note: A dataset can recognize a potentially infinite number of field names (i.e. the field | 61 fields (i.e. an ArrayFieldsDataSet), NaN plays the role of a missing value. |
35 values can be computed on-demand, when particular field names are used in one of the | 62 |
36 iterators). | 63 Dataset elements can be indexed and sub-datasets (with a subset |
37 | 64 of examples) can be extracted. These operations are not supported |
38 Datasets of finite length should be sub-classes of FiniteLengthDataSet. | 65 by default in the case of streams. |
39 | 66 |
40 Datasets whose elements can be indexed and whose sub-datasets (with a subset | 67 * dataset[:n] returns a dataset with the n first examples. |
41 of examples) can be extracted should be sub-classes of | 68 |
42 SliceableDataSet. | 69 * dataset[i1:i2:s] returns a dataset with the examples i1,i1+s,...i2-s. |
43 | 70 |
44 Datasets with a finite number of fields should be sub-classes of | 71 * dataset[i] returns an Example. |
45 FiniteWidthDataSet. | 72 |
46 """ | 73 * dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in. |
47 | 74 |
75 Datasets can be concatenated either vertically (increasing the length) or | |
76 horizontally (augmenting the set of fields), if they are compatible, using | |
77 the following operations (with the same basic semantics as numpy.hstack | |
78 and numpy.vstack): | |
79 | |
80 * dataset1 | dataset2 | dataset3 == dataset.hstack([dataset1,dataset2,dataset3]) | |
81 | |
82 creates a new dataset whose list of fields is the concatenation of the list of | |
83 fields of the argument datasets. This only works if they all have the same length. | |
84 | |
85 * dataset1 + dataset2 + dataset3 == dataset.vstack([dataset1,dataset2,dataset3]) | |
86 | |
87 creates a new dataset that concatenates the examples from the argument datasets | |
88 (and whose length is the sum of the length of the argument datasets). This only | |
89 works if they all have the same fields. | |
90 | |
91 According to the same logic, and viewing a DataSetFields object associated to | |
92 a DataSet as a kind of transpose of it, fields1 + fields2 concatenates fields of | |
93 a DataSetFields fields1 and fields2, and fields1 | fields2 concatenates their | |
94 examples. | |
95 | |
96 | |
97 A DataSet sub-class should always redefine the following methods: | |
98 * __len__ if it is not a stream | |
99 * __getitem__ may not be feasible with some streams | |
100 * fieldNames | |
101 * minibatches | |
102 * valuesHStack | |
103 * valuesVStack | |
104 For efficiency of implementation, a sub-class might also want to redefine | |
105 * hasFields | |
106 """ | |
107 | |
108 infinity = float("infinity") | |
109 | |
48 def __init__(self): | 110 def __init__(self): |
49 pass | 111 pass |
50 | 112 |
51 class Iterator(LookupList): | 113 class MinibatchToSingleExampleIterator(object): |
52 def __init__(self, ll): | 114 """ |
53 LookupList.__init__(self, ll.keys(), ll.values()) | 115 Converts the result of minibatch iterator with minibatch_size==1 into |
54 self.ll = ll | 116 single-example values in the result. Therefore the result of |
117 iterating on the dataset itself gives a sequence of single examples | |
118 (whereas the result of iterating over minibatches gives in each | |
119 Example field an iterable object over the individual examples in | |
120 the minibatch). | |
121 """ | |
122 def __init__(self, minibatch_iterator): | |
123 self.minibatch_iterator = minibatch_iterator | |
55 def __iter__(self): #makes for loop work | 124 def __iter__(self): #makes for loop work |
56 return self | 125 return self |
57 def next(self): | 126 def next(self): |
58 self.ll.next() | 127 return self.minibatch_iterator.next()[0] |
59 self._values = [v[0] for v in self.ll._values] | |
60 return self | |
61 def next_index(self): | 128 def next_index(self): |
62 return self.ll.next_index() | 129 return self.minibatch_iterator.next_index() |
63 | 130 |
64 def __iter__(self): | 131 def __iter__(self): |
65 """Supports the syntax "for i in dataset: ..." | 132 """Supports the syntax "for i in dataset: ..." |
66 | 133 |
67 Using this syntax, "i" will be an Example instance (or equivalent) with | 134 Using this syntax, "i" will be an Example instance (or equivalent) with |
68 all the fields of DataSet self. Every field of "i" will give access to | 135 all the fields of DataSet self. Every field of "i" will give access to |
69 a field of a single example. Fields should be accessible via | 136 a field of a single example. Fields should be accessible via |
70 i["fieldname"] or i[3] (in the order defined by the elements of the | 137 i["fieldname"] or i[3] (in the order defined by the elements of the |
71 Example returned by this iterator), but the derived class is free | 138 Example returned by this iterator), but the derived class is free |
72 to accept any type of identifier, and add extra functionality to the iterator. | 139 to accept any type of identifier, and add extra functionality to the iterator. |
73 """ | 140 |
74 return DataSet.Iterator(self.minibatches(None, minibatch_size = 1)) | 141 The default implementation calls the minibatches iterator and extracts the first example of each field. |
75 | 142 """ |
76 def zip(self, *fieldnames): | 143 return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1)) |
77 """ | |
78 Supports two forms of syntax: | |
79 | |
80 for i in dataset.zip([f1, f2, f3]): ... | |
81 | |
82 for i1, i2, i3 in dataset.zip([f1, f2, f3]): ... | |
83 | |
84 Using the first syntax, "i" will be an indexable object, such as a list, | |
85 tuple, or Example instance, such that on every iteration, i[0] is the f1 | |
86 field of the current example, i[1] is the f2 field, and so on. | |
87 | |
88 Using the second syntax, i1, i2, i3 will contain the the contents of the | |
89 f1, f2, and f3 fields of a single example on each loop iteration. | |
90 | |
91 The derived class may accept fieldname arguments of any type. | |
92 | |
93 """ | |
94 return DataSet.Iterator(self.minibatches(fieldnames, minibatch_size = 1)) | |
95 | 144 |
96 minibatches_fieldnames = None | 145 minibatches_fieldnames = None |
97 minibatches_minibatch_size = 1 | 146 minibatches_minibatch_size = 1 |
98 minibatches_n_batches = None | 147 minibatches_n_batches = None |
99 def minibatches(self, | 148 def minibatches(self, |
100 fieldnames = minibatches_fieldnames, | 149 fieldnames = minibatches_fieldnames, |
101 minibatch_size = minibatches_minibatch_size, | 150 minibatch_size = minibatches_minibatch_size, |
102 n_batches = minibatches_n_batches): | 151 n_batches = minibatches_n_batches): |
103 """ | 152 """ |
104 Supports three forms of syntax: | 153 Return an iterator that supports three forms of syntax: |
105 | 154 |
106 for i in dataset.minibatches(None,**kwargs): ... | 155 for i in dataset.minibatches(None,**kwargs): ... |
107 | 156 |
108 for i in dataset.minibatches([f1, f2, f3],**kwargs): ... | 157 for i in dataset.minibatches([f1, f2, f3],**kwargs): ... |
109 | 158 |
120 of fields is infinite (i.e. field values may be computed "on demand"). | 169 of fields is infinite (i.e. field values may be computed "on demand"). |
121 | 170 |
122 Using the third syntax, i1, i2, i3 will be list-like containers of the | 171 Using the third syntax, i1, i2, i3 will be list-like containers of the |
123 f1, f2, and f3 fields of a batch of examples on each loop iteration. | 172 f1, f2, and f3 fields of a batch of examples on each loop iteration. |
124 | 173 |
174 The minibatches iterator is expected to return upon each call to next() | |
175 a DataSetFields object, which is a LookupList (indexed by the field names) whose | |
176 elements are iterable over the minibatch examples, and which keeps a pointer to | |
177 a sub-dataset that can be used to iterate over the individual examples | |
178 in the minibatch. Hence a minibatch can be converted back to a regular | |
179 dataset or its fields can be looked at individually (and possibly iterated over). | |
180 | |
125 PARAMETERS | 181 PARAMETERS |
126 - fieldnames (list of any type, default None): | 182 - fieldnames (list of any type, default None): |
127 The loop variables i1, i2, i3 (in the example above) should contain the | 183 The loop variables i1, i2, i3 (in the example above) should contain the |
128 f1, f2, and f3 fields of the current batch of examples. If None, the | 184 f1, f2, and f3 fields of the current batch of examples. If None, the |
129 derived class can choose a default, e.g. all fields. | 185 derived class can choose a default, e.g. all fields. |
141 any other object that supports integer indexing and slicing. | 197 any other object that supports integer indexing and slicing. |
142 | 198 |
143 """ | 199 """ |
144 raise AbstractFunction() | 200 raise AbstractFunction() |
145 | 201 |
202 | |
203 def __len__(self): | |
204 """ | |
205 len(dataset) returns the number of examples in the dataset. | |
206 By default, a DataSet is a 'stream', i.e. it has an unbounded (infinite) length. | |
207 Sub-classes which implement finite-length datasets should redefine this method. | |
208 Some methods only make sense for finite-length datasets, and will perform | |
209 assert len(dataset)<DataSet.infinity | |
210 in order to check the finiteness of the dataset. | |
211 """ | |
212 return DataSet.infinity | |
213 | |
146 def hasFields(self,*fieldnames): | 214 def hasFields(self,*fieldnames): |
147 """ | 215 """ |
148 Return true if the given field name (or field names, if multiple arguments are | 216 Return true if the given field name (or field names, if multiple arguments are |
149 given) is recognized by the DataSet (i.e. can be used as a field name in one | 217 given) is recognized by the DataSet (i.e. can be used as a field name in one |
150 of the iterators). | 218 of the iterators). |
219 | |
220 The default implementation may be inefficient (O(# fields in dataset)), as it calls the fieldNames() | |
221 method. Many datasets may store their field names in a dictionary, which would allow more efficiency. | |
222 """ | |
223 return len(unique_elements_list_intersection(fieldnames,self.fieldNames()))>0 | |
224 | |
225 def fieldNames(self): | |
226 """ | |
227 Return the list of field names that are supported by the iterators, | |
228 and for which hasFields(fieldname) would return True. | |
151 """ | 229 """ |
152 raise AbstractFunction() | 230 raise AbstractFunction() |
153 | 231 |
154 | 232 def __call__(self,*fieldnames): |
155 def merge_fields(self,*specifications): | 233 """ |
156 """ | 234 Return a dataset that sees only the fields whose names are specified. |
157 Return a new dataset that maps old fields (of self) to new fields (of the returned | 235 """ |
158 dataset). The minimal syntax that should be supported is the following: | 236 assert self.hasFields(fieldnames) |
160 new_field_spec = ([old_field1, old_field2, ...], new_field) | 237 return self.fields(*fieldnames).examples() |
160 new_field_spec = ([old_field1, old_field2, ...], new_field) | 238 |
161 In general both old_field and new_field should be strings, but some datasets may also | 239 def fields(self,*fieldnames): |
162 support additional indexing schemes within each field (e.g. column slice | 240 """ |
163 of a matrix-like field). | 241 Return a DataSetFields object associated with this dataset. |
164 """ | 242 """ |
165 raise AbstractFunction() | 243 return DataSetFields(self,*fieldnames) |
166 | |
167 def merge_field_values(self,*field_value_pairs): | |
168 """ | |
169 Return the value that corresponds to merging the values of several fields, | |
170 given as arguments (field_name, field_value) pairs with self.hasField(field_name). | |
171 This may be used by implementations of merge_fields. | |
172 Raise a ValueError if the operation is not possible. | |
173 """ | |
174 fieldnames,fieldvalues = zip(*field_value_pairs) | |
175 raise ValueError("Unable to merge values of these fields:"+repr(fieldnames)) | |
176 | |
177 def examples2minibatch(self,examples): | |
178 """ | |
179 Combine a list of Examples into a minibatch. A minibatch is an Example whose fields | |
180 are iterable over the examples of the minibatch. | |
181 """ | |
182 raise AbstractFunction() | |
183 | |
184 def rename(self,rename_dict): | |
185 """ | |
186 Changes a dataset into one that renames fields, using a dictionnary that maps old field | |
187 names to new field names. The only fields visible by the returned dataset are those | |
188 whose names are keys of the rename_dict. | |
189 """ | |
190 self_class = self.__class__ | |
191 class SelfRenamingDataSet(RenamingDataSet,self_class): | |
192 pass | |
193 self.__class__ = SelfRenamingDataSet | |
194 # set the rename_dict and src fields | |
195 SelfRenamingDataSet.__init__(self,self,rename_dict) | |
196 return self | |
197 | |
198 def apply_function(self,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): | |
199 """ | |
200 Changes a dataset into one that contains as fields the results of applying | |
201 the given function (example-wise) to the specified input_fields. The | |
202 function should return a sequence whose elements will be stored in | |
203 fields whose names are given in the output_fields list. If copy_inputs | |
204 is True then the resulting dataset will also contain the fields of self. | |
205 If accept_minibatches, then the function may be called | |
206 with minibatches as arguments (what is returned by the minibatches | |
207 iterator). In any case, the computations may be delayed until the examples | |
208 of the resulting dataset are requested. If cache is True, then | |
209 once the output fields for some examples have been computed, then | |
210 are cached (to avoid recomputation if the same examples are again | |
211 requested). | |
212 """ | |
213 self_class = self.__class__ | |
214 class SelfApplyFunctionDataSet(ApplyFunctionDataSet,self_class): | |
215 pass | |
216 self.__class__ = SelfApplyFunctionDataSet | |
217 # set the required additional fields | |
218 ApplyFunctionDataSet.__init__(self,self,function, input_fields, output_fields, copy_inputs, accept_minibatches, cache) | |
219 return self | |
220 | |
221 | |
222 class FiniteLengthDataSet(DataSet): | |
223 """ | |
224 Virtual interface for datasets that have a finite length (number of examples), | |
225 and thus recognize a len(dataset) call. | |
226 """ | |
227 def __init__(self): | |
228 DataSet.__init__(self) | |
229 | |
230 def __len__(self): | |
231 """len(dataset) returns the number of examples in the dataset.""" | |
232 raise AbstractFunction() | |
233 | |
234 def __call__(self,fieldname_or_fieldnames): | |
235 """ | |
236 Extract one or more fields. This may be an expensive operation when the | |
237 dataset is large. It is not the recommanded way to access individual values | |
238 (use the iterators instead). If the argument is a string fieldname, then the result | |
239 is a sequence (iterable object) of values for that field, for the whole dataset. If the | |
240 argument is a list of field names, then the result is a 'batch', i.e., an Example with keys | |
241 corresponding to the given field names and values being iterable objects over the | |
242 individual example values. | |
243 """ | |
244 if type(fieldname_or_fieldnames) is string: | |
245 minibatch = self.minibatches([fieldname_or_fieldnames],len(self)).next() | |
246 return minibatch[fieldname_or_fieldnames] | |
247 return self.minibatches(fieldname_or_fieldnames,len(self)).next() | |
248 | |
249 class SliceableDataSet(DataSet): | |
250 """ | |
251 Virtual interface, a subclass of DataSet for datasets which are sliceable | |
252 and whose individual elements can be accessed, generally respecting the | |
253 python semantics for [spec], where spec is either a non-negative integer | |
254 (for selecting one example), a python slice(start,stop,step) for selecting a regular | |
255 sub-dataset comprising examples start,start+step,start+2*step,...,n (with n<stop), or a | |
256 sequence (e.g. a list) of integers [i1,i2,...,in] for selecting | |
257 an arbitrary subset of examples. This is useful for obtaining | |
258 sub-datasets, e.g. for splitting a dataset into training and test sets. | |
259 """ | |
260 def __init__(self): | |
261 DataSet.__init__(self) | |
262 | |
263 def minibatches(self, | |
264 fieldnames = DataSet.minibatches_fieldnames, | |
265 minibatch_size = DataSet.minibatches_minibatch_size, | |
266 n_batches = DataSet.minibatches_n_batches): | |
267 """ | |
268 If the n_batches is empty, we want to see all the examples possible | |
269 for the given minibatch_size (possibly missing a few at the end of the dataset). | |
270 """ | |
271 # substitute the defaults: | |
272 if n_batches is None: n_batches = len(self) / minibatch_size | |
273 return DataSet.Iterator(self, fieldnames, minibatch_size, n_batches) | |
274 | 244 |
275 def __getitem__(self,i): | 245 def __getitem__(self,i): |
276 """ | 246 """ |
277 dataset[i] returns the (i+1)-th example of the dataset. | 247 dataset[i] returns the (i+1)-th example of the dataset. |
278 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. | 248 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. |
279 dataset[i:j:s] returns the subdataset with examples i,i+s,i+2*s,... (up to but excluding j). | 249 dataset[i:j:s] returns the subdataset with examples i,i+s,i+2*s,... (up to but excluding j). |
280 dataset[[i1,i2,..,in]] returns the subdataset with examples i1,i2,...,in. | 250 dataset[[i1,i2,..,in]] returns the subdataset with examples i1,i2,...,in. |
281 """ | 251 |
282 raise AbstractFunction() | 252 Note that some stream datasets may be unable to implement slicing/indexing |
283 | 253 because they can only iterate through examples one at a time or one minibatch at a time |
284 def __getslice__(self,*slice_args): | 254 and do not actually store or keep past (or future) examples. |
285 """ | 255 """ |
286 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. | 256 raise NotImplementedError() |
287 dataset[i:j:s] returns the subdataset with examples i,i+2,i+4...,j-2. | 257 |
288 """ | 258 def valuesHStack(self,fieldnames,fieldvalues): |
289 raise AbstractFunction() | 259 """ |
290 | 260 Return a value that corresponds to concatenating (horizontally) several field values. |
291 | 261 This can be useful to merge some fields. The implementation of this operation is likely |
292 class FiniteWidthDataSet(DataSet): | 262 to involve a copy of the original values. When the values are numpy arrays, the |
293 """ | 263 result should be numpy.hstack(values). If it makes sense, this operation should |
294 Virtual interface for datasets that have a finite width (number of fields), | 264 work as well when each value corresponds to multiple examples in a minibatch |
295 and thus return a list of fieldNames. | 265 e.g. if each value is a Ni-vector and a minibatch of length L is a LxNi matrix, |
296 """ | 266 then the result should be a Lx(N1+N2+..) matrix equal to numpy.hstack(values). |
297 def __init__(self): | 267 The default is to use numpy.hstack for numpy.ndarray values, and a list |
298 DataSet.__init__(self) | 268 pointing to the original values for other data types. |
299 | 269 """ |
300 def hasFields(self,*fields): | 270 all_numpy=True |
301 has_fields=True | 271 for value in fieldvalues: |
302 fieldnames = self.fieldNames() | 272 if not type(value) is numpy.ndarray: |
303 for name in fields: | 273 all_numpy=False |
304 if name not in fieldnames: | 274 if all_numpy: |
305 has_fields=False | 275 return numpy.hstack(fieldvalues) |
306 return has_fields | 276 # the default implementation of horizontal stacking is to put values in a list |
307 | 277 return fieldvalues |
278 | |
279 | |
280 def valuesVStack(self,fieldname,values): | |
281 """ | |
282 Return a value that corresponds to concatenating (vertically) several values of the | |
283 same field. This can be important to build a minibatch out of individual examples. This | |
284 is likely to involve a copy of the original values. When the values are numpy arrays, the | |
285 result should be numpy.vstack(values). | |
286 The default is to use numpy.vstack for numpy.ndarray values, and a list | |
287 pointing to the original values for other data types. | |
288 """ | |
289 all_numpy=True | |
290 for value in values: | |
291 if not type(value) is numpy.ndarray: | |
292 all_numpy=False | |
293 if all_numpy: | |
294 return numpy.vstack(values) | |
295 # the default implementation of vertical stacking is to put values in a list | |
296 return values | |
297 | |
298 def __or__(self,other): | |
299 """ | |
300 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of | |
301 fields of the argument datasets. This only works if they all have the same length. | |
302 """ | |
303 return HStackedDataSet([self,other]) | |
304 | |
305 def __add__(self,other): | |
306 """ | |
307 dataset1 + dataset2 is a dataset that concatenates the examples from the argument datasets | |
308 (and whose length is the sum of the length of the argument datasets). This only | |
309 works if they all have the same fields. | |
310 """ | |
311 return VStackedDataSet([self,other]) | |
312 | |
313 def hstack(datasets): | |
314 """ | |
315 hstack(dataset1,dataset2,...) returns dataset1 | dataset2 | ... | |
316 which is a dataset whose fields list is the concatenation of the fields | |
317 of the individual datasets. | |
318 """ | |
319 assert len(datasets)>0 | |
320 if len(datasets)==1: | |
321 return datasets[0] | |
322 return HStackedDataSet(datasets) | |
323 | |
324 def vstack(datasets): | |
325 """ | |
326 vstack(dataset1,dataset2,...) returns dataset1 + dataset2 + ... | |
327 which is a dataset which iterates first over the examples of dataset1, then | |
328 over those of dataset2, etc. | |
329 """ | |
330 assert len(datasets)>0 | |
331 if len(datasets)==1: | |
332 return datasets[0] | |
333 return VStackedDataSet(datasets) | |
334 | |
335 | |
336 class DataSetFields(LookupList): | |
337 """ | |
338 Although a DataSet iterates over examples (like rows of a matrix), an associated | |
339 DataSetFields iterates over fields (like columns of a matrix), and can be understood | |
340 as a transpose of the associated dataset. | |
341 | |
342 To iterate over fields, one can do | |
343 * for fields in dataset.fields() | |
344 * for fields in dataset(field1,field2,...).fields() to select a subset of fields | |
345 * for fields in dataset.fields(field1,field2,...) to select a subset of fields | |
346 and each of these fields is iterable over the examples: | |
347 * for field_examples in dataset.fields(): | |
348 for example_value in field_examples: | |
349 ... | |
350 but when the dataset is a stream (unbounded length), it is not recommended to do | |
351 such things because the underlying dataset may refuse to access the different fields in | |
352 an unsynchronized way. Hence the fields() method is illegal for streams, by default. | |
353 The result of fields() is a DataSetFields object, which iterates over fields, | |
354 and whose elements are iterable over examples. A DataSetFields object can | |
355 be turned back into a DataSet with its examples() method: | |
356 dataset2 = dataset1.fields().examples() | |
357 and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1). | |
358 """ | |
359 def __init__(self,dataset,*fieldnames): | |
360 self.dataset=dataset | |
361 assert dataset.hasFields(*fieldnames) | |
362 LookupList.__init__(self,fieldnames if len(fieldnames)>0 else dataset.fieldNames(), | |
363 dataset.minibatches(fieldnames if len(fieldnames)>0 else dataset.fieldNames(),minibatch_size=len(dataset)).next()) | |
364 def examples(self): | |
365 return self.dataset | |
366 | |
367 def __or__(self,other): | |
368 """ | |
369 fields1 | fields2 is a DataSetFields whose list of examples is the concatenation | |
370 of the list of examples of DataSetFields fields1 and fields2. | |
371 """ | |
372 return (self.examples() + other.examples()).fields() | |
373 | |
374 def __add__(self,other): | |
375 """ | |
376 fields1 + fields2 is a DataSetFields whose list of fields is the concatenation | |
377 of the fields of DataSetFields fields1 and fields2. | |
378 """ | |
379 return (self.examples() | other.examples()).fields() | |
380 | |
381 class MinibatchDataSet(DataSet): | |
382 """ | |
383 Turn a LookupList of same-length fields into an example-iterable dataset. | |
384 Each element of the lookup-list should be an iterable and sliceable, all of the same length. | |
385 """ | |
386 def __init__(self,fields_lookuplist,values_vstack=DataSet().valuesVStack, | |
387 values_hstack=DataSet().valuesHStack): | |
388 """ | |
389 The user can (and generally should) also provide values_vstack(fieldname,fieldvalues) | |
390 and values_hstack(fieldnames,fieldvalues) functions behaving with the same | |
391 semantics as the DataSet methods of the same name (but without the self argument). | |
392 """ | |
393 self.fields=fields_lookuplist | |
394 assert len(fields_lookuplist)>0 | |
395 self.length=len(fields_lookuplist[0]) | |
396 for field in fields_lookuplist[1:]: | |
397 assert self.length==len(field) | |
398 self.values_vstack=values_vstack | |
399 self.values_hstack=values_hstack | |
400 | |
401 def __len__(self): | |
402 return self.length | |
403 | |
404 def __getitem__(self,i): | |
405 return Example(self.fields.keys(),[field[i] for field in self.fields]) | |
406 | |
308 def fieldNames(self): | 407 def fieldNames(self): |
309 """Return the list of field names that are supported by the iterators, | 408 return self.fields.keys() |
310 and for which hasFields(fieldname) would return True.""" | 409 |
311 raise AbstractFunction() | 410 def hasFields(self,*fieldnames): |
312 | 411 for fieldname in fieldnames: |
313 | 412 if fieldname not in self.fields: |
314 class RenamingDataSet(FiniteWidthDataSet): | 413 return False |
315 """A DataSet that wraps another one, and makes it look like the field names | 414 return True |
316 are different | 415 |
317 | |
318 Renaming is done by a dictionary that maps new names to the old ones used in | |
319 self.src. | |
320 """ | |
321 def __init__(self, src, rename_dct): | |
322 DataSet.__init__(self) | |
323 self.src = src | |
324 self.rename_dct = copy.copy(rename_dct) | |
325 | |
326 def fieldNames(self): | |
327 return self.rename_dct.keys() | |
328 | |
329 def minibatches(self, | 416 def minibatches(self, |
330 fieldnames = DataSet.minibatches_fieldnames, | 417 fieldnames = DataSet.minibatches_fieldnames, |
331 minibatch_size = DataSet.minibatches_minibatch_size, | 418 minibatch_size = DataSet.minibatches_minibatch_size, |
332 n_batches = DataSet.minibatches_n_batches): | 419 n_batches = DataSet.minibatches_n_batches): |
333 dct = self.rename_dct | 420 class Iterator(object): |
334 new_fieldnames = [dct.get(f, f) for f in fieldnames] | 421 def __init__(self,ds): |
335 return self.src.minibatches(new_fieldnames, minibatches_size, n_batches) | 422 self.ds=ds |
336 | 423 self.next_example=0 |
337 | 424 self.n_batches_done=0 |
338 # we may want ArrayDataSet defined in another python file | 425 assert minibatch_size > 0 |
339 | 426 if minibatch_size > ds.length: |
340 import numpy | 427 raise NotImplementedError() |
341 | |
342 def as_array_dataset(dataset): | |
343 # Generally datasets can be efficient by making data fields overlap, but | |
344 # this function doesn't know which fields overlap. So, it should check if | |
345 # dataset supports an as_array_dataset member function, and return that if | |
346 # possible. | |
347 if hasattr(dataset, 'as_array_dataset'): | |
348 return dataset.as_array_dataset() | |
349 | |
350 raise NotImplementedError | |
351 | |
352 # Make ONE big minibatch with all the examples, to separate the fields. | |
353 n_examples = len(dataset) | |
354 batch = dataset.minibatches( minibatch_size = len(dataset)).next() | |
355 | |
356 # Each field of the underlying dataset must be convertible to a numpy array of the same type | |
357 # currently just double, but should use the smallest compatible dtype | |
358 n_fields = len(batch) | |
359 fieldnames = batch.fields.keys() | |
360 total_width = 0 | |
361 type = None | |
362 fields = LookupList() | |
363 for i in xrange(n_fields): | |
364 field = array(batch[i]) | |
365 assert field.shape[0]==n_examples | |
366 width = field.shape[1] | |
367 start=total_width | |
368 total_width += width | |
369 fields[fieldnames[i]]=slice(start,total_width,1) | |
370 # many complicated things remain to be done: | |
371 # - find common dtype | |
372 # - decide what to do with extra dimensions if not the same in all fields | |
373 # - try to see if we can avoid the copy? | |
374 | |
375 class ArrayDataSet(FiniteLengthDataSet,FiniteWidthDataSet,SliceableDataSet): | |
376 """ | |
377 An ArrayDataSet behaves like a numpy array but adds the notion of named fields | |
378 from DataSet (and the ability to view the values of multiple fields as an 'Example'). | |
379 It is a fixed-length and fixed-width dataset | |
380 in which each element is a fixed dimension numpy array or a number, hence the whole | |
381 dataset corresponds to a numpy array. Fields | |
382 must correspond to a slice of array columns or to a list of column numbers. | |
383 If the dataset has fields, | |
384 each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array. | |
385 Any dataset can also be converted to a numpy array (losing the notion of fields | |
386 by the numpy.array(dataset) call. | |
387 """ | |
388 | |
389 class Iterator(LookupList): | |
390 """An iterator over a finite dataset that implements wrap-around""" | |
391 def __init__(self, dataset, fieldnames, minibatch_size, next_max): | |
392 if fieldnames is None: fieldnames = dataset.fieldNames() | |
393 LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) | |
394 self.dataset=dataset | |
395 self.minibatch_size=minibatch_size | |
396 self.next_count = 0 | |
397 self.next_max = next_max | |
398 self.current = -self.minibatch_size | |
399 assert minibatch_size > 0 | |
400 if minibatch_size >= len(dataset): | |
401 raise NotImplementedError() | |
402 | |
403 def __iter__(self): #makes for loop work | |
404 return self | |
405 | |
406 @staticmethod | |
407 def matcat(a, b): | |
408 a0, a1 = a.shape | |
409 b0, b1 = b.shape | |
410 assert a1 == b1 | |
411 assert a.dtype is b.dtype | |
412 rval = numpy.empty( (a0 + b0, a1), dtype=a.dtype) | |
413 rval[:a0,:] = a | |
414 rval[a0:,:] = b | |
415 return rval | |
416 | |
417 def next_index(self): | |
418 n_rows = self.dataset.data.shape[0] | |
419 next_i = self.current+self.minibatch_size | |
420 if next_i >= n_rows: | |
421 next_i -= n_rows | |
422 return next_i | |
423 | |
424 def next(self): | |
425 | |
426 #check for end-of-loop | |
427 self.next_count += 1 | |
428 if self.next_count == self.next_max: | |
429 raise StopIteration | |
430 | |
431 #determine the first and last elements of the minibatch slice we'll return | |
432 n_rows = self.dataset.data.shape[0] | |
433 self.current = self.next_index() | |
434 upper = self.current + self.minibatch_size | |
435 | |
436 data = self.dataset.data | |
437 | |
438 if upper <= n_rows: | |
439 #this is the easy case, we only need once slice | |
440 dataview = data[self.current:upper] | |
441 else: | |
442 # the minibatch wraps around the end of the dataset | |
443 dataview = data[self.current:] | |
444 upper -= n_rows | |
445 assert upper > 0 | |
446 dataview = self.matcat(dataview, data[:upper]) | |
447 | |
448 self._values = [dataview[:, self.dataset.fields[f]]\ | |
449 for f in self._names] | |
450 return self | |
451 | |
452 | |
453 def __init__(self, data, fields=None): | |
454 """ | |
455 There are two ways to construct an ArrayDataSet: (1) from an | |
456 existing dataset (which may result in a copy of the data in a numpy array), | |
457 or (2) from a numpy.array (the data argument), along with an optional description | |
458 of the fields (a LookupList of column slices (or column lists) indexed by field names). | |
459 """ | |
460 self.data=data | |
461 self.fields=fields | |
462 rows, cols = data.shape | |
463 | |
464 if fields: | |
465 for fieldname,fieldslice in fields.items(): | |
466 assert type(fieldslice) is int or isinstance(fieldslice,slice) or hasattr(fieldslice,"__iter__") | |
467 if hasattr(fieldslice,"__iter__"): # is a sequence | |
468 for i in fieldslice: | |
469 assert type(i) is int | |
470 elif isinstance(fieldslice,slice): | |
471 # make sure fieldslice.start and fieldslice.step are defined | |
472 start=fieldslice.start | |
473 step=fieldslice.step | |
474 if not start: | |
475 start=0 | |
476 if not step: | |
477 step=1 | |
478 if not fieldslice.start or not fieldslice.step: | |
479 fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step) | |
480 # and coherent with the data array | |
481 assert fieldslice.start >= 0 and fieldslice.stop <= cols | |
482 | |
483 def minibatches(self, | |
484 fieldnames = DataSet.minibatches_fieldnames, | |
485 minibatch_size = DataSet.minibatches_minibatch_size, | |
486 n_batches = DataSet.minibatches_n_batches): | |
487 """ | |
488 If the fieldnames list is None, it means that we want to see ALL the fields. | |
489 | |
490 If the n_batches is None, we want to see all the examples possible | |
491 for the given minibatch_size (possibly missing some near the end). | |
492 """ | |
493 # substitute the defaults: | |
494 if n_batches is None: n_batches = len(self) / minibatch_size | |
495 return ArrayDataSet.Iterator(self, fieldnames, minibatch_size, n_batches) | |
496 | |
497 def fieldNames(self): | |
498 """Return the list of field names that are supported by getattr and hasField.""" | |
499 return self.fields.keys() | |
500 | |
501 def __len__(self): | |
502 """len(dataset) returns the number of examples in the dataset.""" | |
503 return len(self.data) | |
504 | |
505 def __getitem__(self,i): | |
506 """ | |
507 dataset[i] returns the (i+1)-th Example of the dataset. | |
508 If there are no fields the result is just a numpy array (for the i-th row of the dataset data matrix). | |
509 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. | |
510 dataset[i:j:s] returns the subdataset with examples i,i+2,i+4...,j-2. | |
511 dataset[[i1,i2,..,in]] returns the subdataset with examples i1,i2,...,in. | |
512 """ | |
513 if self.fields: | |
514 fieldnames,fieldslices=zip(*self.fields.items()) | |
515 return Example(self.fields.keys(),[self.data[i,fieldslice] for fieldslice in self.fields.values()]) | |
516 else: | |
517 return self.data[i] | |
518 | |
519 def __getslice__(self,*args): | |
520 """ | |
521 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. | |
522 dataset[i:j:s] returns the subdataset with examples i,i+2,i+4...,j-2. | |
523 """ | |
524 return ArrayDataSet(self.data.__getslice__(*args), fields=self.fields) | |
525 | |
526 def indices_of_unique_columns_used(self): | |
527 """ | |
528 Return the unique indices of the columns actually used by the fields, and a boolean | |
529 that signals (if True) that used columns overlap. If they do then the | |
530 indices are not repeated in the result. | |
531 """ | |
532 columns_used = numpy.zeros((self.data.shape[1]),dtype=bool) | |
533 overlapping_columns = False | |
534 for field_slice in self.fields.values(): | |
535 if sum(columns_used[field_slice])>0: overlapping_columns=True | |
536 columns_used[field_slice]=True | |
537 return [i for i,used in enumerate(columns_used) if used],overlapping_columns | |
538 | |
539 def slice_of_unique_columns_used(self): | |
540 """ | |
541 Return None if the indices_of_unique_columns_used do not form a slice. If they do, | |
542 return that slice. It means that the columns used can be extracted | |
543 from the data array without making a copy. If the fields overlap | |
544 but their unique columns used form a slice, still return that slice. | |
545 """ | |
546 columns_used,overlapping_columns = self.indices_of_columns_used() | |
547 mappable_to_one_slice = True | |
548 if not overlapping_fields: | |
549 start=0 | |
550 while start<len(columns_used) and not columns_used[start]: | |
551 start+=1 | |
552 stop=len(columns_used) | |
553 while stop>0 and not columns_used[stop-1]: | |
554 stop-=1 | |
555 step=0 | |
556 i=start | |
557 while i<stop: | |
558 j=i+1 | |
559 while j<stop and not columns_used[j]: | |
560 j+=1 | |
561 if step: | |
562 if step!=j-i: | |
563 mappable_to_one_slice = False | |
564 break | |
565 else: | |
566 step = j-i | |
567 i=j | |
568 return slice(start,stop,step) | |
569 | |
570 class ApplyFunctionDataSet(FiniteWidthDataSet): | |
571 """ | |
572 A dataset that contains as fields the results of applying | |
573 a given function (example-wise) to specified input_fields of a source | |
574 dataset. The function should return a sequence whose elements will be stored in | |
575 fields whose names are given in the output_fields list. If copy_inputs | |
576 is True then the resulting dataset will also contain the fields of the source. | |
577 dataset. If accept_minibatches, then the function expects | |
578 minibatches as arguments (what is returned by the minibatches | |
579 iterator). In any case, the computations may be delayed until the examples | |
580 of self are requested. If cache is True, then | |
581 once the output fields for some examples have been computed, then | |
582 are cached (to avoid recomputation if the same examples are again requested). | |
583 """ | |
584 def __init__(src,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True, compute_now=False): | |
585 DataSet.__init__(self) | |
586 self.src=src | |
587 self.function=function | |
588 assert src.hasFields(input_fields) | |
589 self.input_fields=input_fields | |
590 self.output_fields=output_fields | |
591 assert not (copy_inputs and compute_now and not hasattr(src,'fieldNames')) | |
592 self.copy_inputs=copy_inputs | |
593 self.accept_minibatches=accept_minibatches | |
594 self.cache=cache | |
595 self.compute_now=compute_now | |
596 if compute_now: | |
597 assert hasattr(src,'__len__') and len(src)>=0 | |
598 fieldnames = output_fields | |
599 if copy_inputs: fieldnames = src.fieldNames() + output_fields | |
600 if accept_minibatches: | |
601 # make a single minibatch with all the inputs | |
602 inputs = src.minibatches(input_fields,len(src)).next() | |
603 # and apply the function to it, and transpose into a list of examples (field values, actually) | |
604 self.cached_examples = zip(*Example(output_fields,function(*inputs))) | |
605 else: | |
606 # compute a list with one tuple per example, with the function outputs | |
607 self.cached_examples = [ function(input) for input in src.zip(input_fields) ] | |
608 elif cache: | |
609 # maybe a fixed-size array kind of structure would be more efficient than a list | |
610 # in the case where src is FiniteDataSet. -YB | |
611 self.cached_examples = [] | |
612 | |
613 def fieldNames(self): | |
614 if self.copy_inputs: | |
615 return self.output_fields + self.src.fieldNames() | |
616 return self.output_fields | |
617 | |
618 def minibatches(self, | |
619 fieldnames = DataSet.minibatches_fieldnames, | |
620 minibatch_size = DataSet.minibatches_minibatch_size, | |
621 n_batches = DataSet.minibatches_n_batches): | |
622 | |
623 class Iterator(LookupList): | |
624 | |
625 def __init__(self,dataset): | |
626 if fieldnames is None: | |
627 assert hasattr(dataset,"fieldNames") | |
628 fieldnames = dataset.fieldNames() | |
629 self.example_index=0 | |
630 LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) | |
631 self.dataset=dataset | |
632 self.src_iterator=self.src.minibatches(list(set.union(set(fieldnames),set(dataset.input_fields))), | |
633 minibatch_size,n_batches) | |
634 self.fieldnames_not_in_input = [] | |
635 if self.copy_inputs: | |
636 self.fieldnames_not_in_input = filter(lambda x: not x in dataset.input_fields, fieldnames) | |
637 | |
638 def __iter__(self): | 428 def __iter__(self): |
639 return self | 429 return self |
640 | |
641 def next_index(self): | 430 def next_index(self): |
642 return self.src_iterator.next_index() | 431 return self.next_example |
432 def next(self): | |
433 upper = self.next_example+minibatch_size | |
434 if upper<=self.ds.length: | |
435 minibatch = Example(self.ds.fields.keys(), | |
436 [field[self.next_example:upper] | |
437 for field in self.ds.fields]) | |
438 else: # we must concatenate (vstack) the bottom and top parts of our minibatch | |
439 minibatch = Example(self.ds.fields.keys(), | |
440 [self.ds.valuesVStack(name,[value[self.next_example:], | |
441 value[0:upper-self.ds.length]]) | |
442 for name,value in self.ds.fields.items()]) | |
443 self.next_example+=minibatch_size | |
444 self.n_batches_done+=1 | |
445 if n_batches: | |
446 if self.n_batches_done==n_batches: | |
447 raise StopIteration | |
448 if self.next_example>=self.ds.length: | |
449 self.next_example-=self.ds.length | |
450 else: | |
451 if self.next_example>=self.ds.length: | |
452 raise StopIteration | |
453 return DataSetFields(MinibatchDataSet(minibatch),fieldnames) | |
454 | |
455 return Iterator(self) | |
456 | |
457 def valuesVStack(self,fieldname,fieldvalues): | |
458 return self.values_vstack(fieldname,fieldvalues) | |
459 | |
460 def valuesHStack(self,fieldnames,fieldvalues): | |
461 return self.values_hstack(fieldnames,fieldvalues) | |
462 | |
463 class HStackedDataSet(DataSet): | |
464 """ | |
465 A DataSet that wraps several datasets and shows a view that includes all their fields, | |
466 i.e. whose list of fields is the concatenation of their lists of fields. | |
467 | |
468 If a field name is found in more than one of the datasets, then either an error is | |
469 raised or the fields are renamed (either by prefixing the __name__ attribute | |
470 of the dataset + ".", if it exists, or by suffixing the dataset index in the argument list). | |
471 | |
472 TODO: automatically detect a chain of stacked datasets due to A | B | C | D ... | |
473 """ | |
474 def __init__(self,datasets,accept_nonunique_names=False): | |
475 DataSet.__init__(self) | |
476 self.datasets=datasets | |
477 self.accept_nonunique_names=accept_nonunique_names | |
478 self.fieldname2dataset={} | |
479 | |
480 def rename_field(fieldname,dataset,i): | |
481 if hasattr(dataset,"__name__"): | |
482 return dataset.__name__ + "." + fieldname | |
483 return fieldname+"."+str(i) | |
643 | 484 |
485 # make sure all datasets have the same length and unique field names | |
486 self.length=None | |
487 names_to_change=[] | |
488 for i in xrange(len(datasets)): | |
489 dataset = datasets[i] | |
490 length=len(dataset) | |
491 if self.length: | |
492 assert self.length==length | |
493 else: | |
494 self.length=length | |
495 for fieldname in dataset.fieldNames(): | |
496 if fieldname in self.fieldname2dataset: # name conflict! | |
497 if accept_nonunique_names: | |
498 fieldname=rename_field(fieldname,dataset,i) | |
499 names_to_change.append((fieldname,i)) | |
500 else: | |
501 raise ValueError("Incompatible datasets: non-unique field name = "+fieldname) | |
502 self.fieldname2dataset[fieldname]=i | |
503 for fieldname,i in names_to_change: | |
504 del self.fieldname2dataset[fieldname] | |
505 self.fieldname2dataset[rename_field(fieldname,self.datasets[i],i)]=i | |
506 | |
507 def hasFields(self,*fieldnames): | |
508 for fieldname in fieldnames: | |
509 if not fieldname in self.fieldname2dataset: | |
510 return False | |
511 return True | |
512 | |
513 def fieldNames(self): | |
514 return self.fieldname2dataset.keys() | |
515 | |
516 def minibatches(self, | |
517 fieldnames = DataSet.minibatches_fieldnames, | |
518 minibatch_size = DataSet.minibatches_minibatch_size, | |
519 n_batches = DataSet.minibatches_n_batches): | |
520 | |
521 class Iterator(object): | |
522 def __init__(self,hsds,iterators): | |
523 self.hsds=hsds | |
524 self.iterators=iterators | |
525 def __iter__(self): | |
526 return self | |
527 def next_index(self): | |
528 return self.iterators[0].next_index() | |
644 def next(self): | 529 def next(self): |
645 example_index = self.src_iterator.next_index() | 530 # concatenate all the fields of the minibatches |
646 src_examples = self.src_iterator.next() | 531 minibatch = reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators]) |
647 if self.dataset.copy_inputs: | 532 # and return a DataSetFields whose dataset is the transpose (=examples()) of this minibatch |
648 function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields] | 533 return DataSetFields(MinibatchDataSet(minibatch,self.hsds.valuesVStack, |
649 else: | 534 self.hsds.valuesHStack), |
650 function_inputs = src_examples | 535 fieldnames if fieldnames else hsds.fieldNames()) |
651 if self.dataset.cached_examples: | 536 |
652 cache_len=len(self.cached_examples) | 537 assert self.hasfields(fieldnames) |
653 if example_index<cache_len+minibatch_size: | 538 # find out which underlying datasets are necessary to service the required fields |
654 outputs_list = self.cached_examples[example_index:example_index+minibatch_size] | 539 # and construct corresponding minibatch iterators |
655 # convert the minibatch list of examples | 540 if fieldnames: |
656 # into a list of fields each of which iterate over the minibatch | 541 datasets=set([]) |
657 outputs = zip(*outputs_list) | 542 fields_in_dataset=dict([(dataset,[]) for dataset in datasets]) |
658 else: | 543 for fieldname in fieldnames: |
659 outputs = self.dataset.function(*function_inputs) | 544 dataset=self.datasets[self.fieldnames2dataset[fieldname]] |
660 if self.dataset.cache: | 545 datasets.add(dataset) |
661 # convert the list of fields, each of which can iterate over the minibatch | 546 fields_in_dataset[dataset].append(fieldname) |
662 # into a list of examples in the minibatch (each of which is a list of field values) | 547 datasets=list(datasets) |
663 outputs_list = zip(*outputs) | 548 iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches) |
664 # copy the outputs_list into the cache | 549 for dataset in datasets] |
665 for i in xrange(cache_len,example_index): | 550 else: |
666 self.cached_examples.append(None) | 551 datasets=self.datasets |
667 self.cached_examples += outputs_list | 552 iterators=[dataset.minibatches(None,minibatch_size,n_batches) for dataset in datasets] |
668 else: | 553 return Iterator(self,iterators) |
669 outputs = self.dataset.function(*function_inputs) | 554 |
670 | 555 |
671 return Example(self.fieldnames_not_in_input+self.dataset.output_fields, | 556 def valuesVStack(self,fieldname,fieldvalues): |
672 [src_examples[field_name] for field_name in self.fieldnames_not_in_input]+outputs) | 557 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues) |
673 | 558 |
674 | 559 def valuesHStack(self,fieldnames,fieldvalues): |
675 for fieldname in fieldnames: | 560 """ |
676 assert fieldname in self.output_fields or self.src.hasFields(fieldname) | 561 We will use the sub-dataset associated with the first fieldname in the fieldnames list |
677 return Iterator(self) | 562 to do the work, hoping that it can cope with the other values (i.e. won't care |
678 | 563 about the incompatible fieldnames). Hence this heuristic will always work if |
679 | 564 all the fieldnames belong to the same sub-dataset. |
565 """ | |
566 return self.datasets[self.fieldname2dataset[fieldnames[0]]].valuesHStack(fieldnames,fieldvalues) | |
567 | |
568 class VStackedDataSet(DataSet): | |
569 """ | |
570 A DataSet that wraps several datasets and shows a view that includes all their examples, | |
571 in the order provided. This clearly assumes that they all have the same field names | |
572 and all (except possibly the last one) are of finite length. | |
573 | |
574 TODO: automatically detect a chain of stacked datasets due to A + B + C + D ... | |
575 """ | |
576 def __init__(self,datasets): | |
577 self.datasets=datasets | |
578 self.length=0 | |
579 self.index2dataset={} | |
580 # we use this map from row index to dataset index for constant-time random access of examples, | |
581 # to avoid having to search for the appropriate dataset each time an example or slice is asked for | |
582 for k,dataset in enumerate(datasets[0:-1]): | |
583 L=len(dataset) | |
584 assert L<DataSet.infinity | |
585 for i in xrange(L): | |
586 self.index2dataset[self.length+i]=k | |
587 self.length+=L | |
588 self.last_start=self.length | |
589 self.length+=len(datasets[-1]) | |
590 | |
591 | |
680 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 592 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
681 """ | 593 """ |
682 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the | 594 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the |
683 user to define a set of fields as the 'input' field and a set of fields | 595 user to define a set of fields as the 'input' field and a set of fields |
684 as the 'target' field. Optionally, a single weight_field can also be defined. | 596 as the 'target' field. Optionally, a single weight_field can also be defined. |
685 """ | 597 """ |
686 args = ((input_fields,'input'),(target_fields,'target')) | 598 args = ((input_fields,'input'),(target_fields,'target')) |
687 if weight_field: args+=(([weight_field],'weight'),) | 599 if weight_field: args+=(([weight_field],'weight'),) |
688 return src_dataset.rename(*args) | 600 return src_dataset.merge_fields(*args) |
689 | 601 |
690 | 602 |
691 | 603 |
692 | 604 |
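
As a quick illustration of the interface documented in the new DataSet docstring (example-wise iteration, minibatch iteration, and the | / + concatenation operators), here is a minimal usage sketch. It is not part of the changeset: the dataset names xy and w, the toy numpy arrays, and the assumption that this file is importable as the module dataset are invented for the example, and the sketch follows the documented behaviour of this work-in-progress revision rather than guaranteeing that it runs unmodified.

    import numpy
    from lookup_list import LookupList
    from dataset import MinibatchDataSet  # assumes this file is importable as 'dataset'

    # Two small finite datasets built from LookupLists of same-length fields.
    xy = MinibatchDataSet(LookupList(['input', 'target'],
                                     [numpy.random.rand(100, 5),
                                      numpy.random.rand(100, 1)]))
    w = MinibatchDataSet(LookupList(['weight'], [numpy.ones((100, 1))]))

    # Example-wise iteration: each example behaves like a LookupList (an Example).
    for example in xy:
        x, y = example['input'], example['target']

    # Minibatch iteration over a subset of the fields, 20 examples at a time.
    for inputs, targets in xy.minibatches(['input', 'target'], minibatch_size=20):
        pass  # inputs and targets each iterate over the 20 examples of the batch

    # Concatenation, with the hstack/vstack semantics documented above.
    wider = xy | w     # same length, disjoint field names -> fields of both datasets
    longer = xy + xy   # same fields -> examples of both datasets, one after the other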