comparison dataset.py @ 22:b6b36f65664f

Created virtual sub-classes of DataSet: {Finite{Length,Width},Sliceable}DataSet, removed the .field ability from LookupList (because of setattr problems), removed fieldNames() from DataSet (it remains in FiniteWidthDataSet, where it makes sense), and added hasFields() instead. Fixed problems in asarray, and tested the previous functionality in _test_dataset.py, but not yet the new functionality.
author bengioy@esprit.iro.umontreal.ca
date Mon, 07 Apr 2008 20:44:37 -0400
parents 266c68cb6136
children 526e192b0699
comparison
equal deleted inserted replaced
21:fdf0abc490f7 22:b6b36f65664f
8 """A virtual base class for datasets. 8 """A virtual base class for datasets.
9 9
10 A DataSet is a generator of iterators; these iterators can run through the 10 A DataSet is a generator of iterators; these iterators can run through the
11 examples in a variety of ways. A DataSet need not necessarily have a finite 11 examples in a variety of ways. A DataSet need not necessarily have a finite
12 or known length, so this class can be used to interface to a 'stream' which 12 or known length, so this class can be used to interface to a 'stream' which
13 feeds on-line learning. 13 feeds on-line learning.
14 14
15 To iterate over examples, there are several possibilities: 15 To iterate over examples, there are several possibilities:
16 - for example in dataset.zip(field1, field2, field3, ...) 16 - for example in dataset.zip(field1, field2, field3, ...)
17 - for val1,val2,val3 in dataset.zip(field1, field2, field3) 17 - for val1,val2,val3 in dataset.zip(field1, field2, field3)
18 - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N) 18 - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N)
19 - for example in dataset 19 - for example in dataset
20 Each of these is documented below. 20 Each of these is documented below.
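For concreteness, a minimal usage sketch of these styles (the dataset object and the field names 'x' and 'y' are hypothetical):

    # hypothetical dataset with two fields, 'x' and 'y'
    for example in dataset.zip('x', 'y'):
        print example['x'], example['y']    # one Example at a time
    for x, y in dataset.zip('x', 'y'):
        print x, y                          # field values unpacked directly
    for x_batch, y_batch in dataset.minibatches(['x', 'y'], minibatch_size=10):
        print len(x_batch)                  # list-like containers of 10 values each
    for example in dataset:                 # plain iteration: all fields
        print example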
21 21
22 Note: For a dataset of fixed and known length, which can implement item
23 random-access efficiently (e.g. indexing and slicing), and which can profit
24 from the FiniteDataSetIterator, consider using base class FiniteDataSet.
25
26 Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. 22 Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content.
27 23
28 Note: The content of a field can be of any type. 24 Note: The content of a field can be of any type.
29 25
26 Note: A dataset can recognize a potentially infinite number of field names (i.e. the field
27 values can be computed on-demand, when particular field names are used in one of the
28 iterators).
29
30 Datasets of finite length should be sub-classes of FiniteLengthDataSet.
31
32 Datasets whose elements can be indexed, and from which sub-datasets of
33 consecutive examples (i.e. slices) can be extracted, should be sub-classes
34 of SliceableDataSet.
35
36 Datasets with a finite number of fields should be sub-classes of
37 FiniteWidthDataSet.
30 """ 38 """
31 39
32 def __init__(self): 40 def __init__(self):
33 pass 41 pass
34 42
43 class Iter(LookupList):
44 def __init__(self, ll):
45 LookupList.__init__(self, ll.keys(), ll.values())
46 self.ll = ll
47 def __iter__(self): #makes for loop work
48 return self
49 def next(self):
50 self.ll.next()
51 self._values = [v[0] for v in self.ll._values]
52 return self
53
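The Iter wrapper drives a minibatch iterator with minibatch_size=1 and, on each step, replaces every singleton batch by its only element, so the caller sees individual field values rather than length-1 containers. In isolation:

    # standalone illustration of the unpacking performed in Iter.next()
    batch_values = [[3.0], ['label']]              # one minibatch of size 1, two fields
    example_values = [v[0] for v in batch_values]  # what Iter stores in self._values
    print example_values                           # prints [3.0, 'label']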
35 def __iter__(self): 54 def __iter__(self):
36 """Supports the syntax "for i in dataset: ..." 55 """Supports the syntax "for i in dataset: ..."
37 56
38 Using this syntax, "i" will be an Example instance (or equivalent) with 57 Using this syntax, "i" will be an Example instance (or equivalent) with
39 all the fields of DataSet self. Every field of "i" will give access to 58 all the fields of DataSet self. Every field of "i" will give access to
40 a field of a single example. Fields should be accessible via 59 a field of a single example. Fields should be accessible via
41 i["fielname"] or i[3] (in the fieldNames() order), but the derived class is free 60 i["fielname"] or i[3] (in the order defined by the elements of the
61 Example returned by this iterator), but the derived class is free
42 to accept any type of identifier, and add extra functionality to the iterator. 62 to accept any type of identifier, and add extra functionality to the iterator.
43 """ 63 """
44 return self.zip(*self.fieldNames()) 64 return DataSet.Iter(self.minibatches(None, minibatch_size = 1))
45 65
46 def zip(self, *fieldnames): 66 def zip(self, *fieldnames):
47 """ 67 """
48 Supports two forms of syntax: 68 Supports two forms of syntax:
49 69
59 f1, f2, and f3 fields of a single example on each loop iteration. 79 f1, f2, and f3 fields of a single example on each loop iteration.
60 80
61 The derived class may accept fieldname arguments of any type. 81 The derived class may accept fieldname arguments of any type.
62 82
63 """ 83 """
64 class Iter(LookupList): 84 return DataSet.Iter(self.minibatches(fieldnames, minibatch_size = 1))
65 def __init__(self, ll):
66 LookupList.__init__(self, ll.keys(), ll.values())
67 self.ll = ll
68 def __iter__(self): #makes for loop work
69 return self
70 def next(self):
71 self.ll.next()
72 self._values = [v[0] for v in self.ll._values]
73 return self
74 return Iter(self.minibatches(fieldnames, minibatch_size = 1))
75 85
76 minibatches_fieldnames = None 86 minibatches_fieldnames = None
77 minibatches_minibatch_size = 1 87 minibatches_minibatch_size = 1
78 minibatches_n_batches = None 88 minibatches_n_batches = None
79 def minibatches(self, 89 def minibatches(self,
80 fieldnames = minibatches_fieldnames, 90 fieldnames = minibatches_fieldnames,
81 minibatch_size = minibatches_minibatch_size, 91 minibatch_size = minibatches_minibatch_size,
82 n_batches = minibatches_n_batches): 92 n_batches = minibatches_n_batches):
83 """ 93 """
84 Supports two forms of syntax: 94 Supports three forms of syntax:
95
96 for i in dataset.minibatches(None,**kwargs): ...
85 97
86 for i in dataset.minibatches([f1, f2, f3],**kwargs): ... 98 for i in dataset.minibatches([f1, f2, f3],**kwargs): ...
87 99
88 for i1, i2, i3 in dataset.minibatches([f1, f2, f3],**kwargs): ... 100 for i1, i2, i3 in dataset.minibatches([f1, f2, f3],**kwargs): ...
89 101
90 Using the first syntax, "i" will be an indexable object, such as a list, 102 Using the first two syntaxes, "i" will be an indexable object, such as a list,
91 tuple, or Example instance, such that on every iteration, i[0] is a 103 tuple, or Example instance. In both cases, i[k] is a list-like container
104 of the k-th requested field over a batch of current examples. In the second case, i[0] is a
92 list-like container of the f1 field of a batch of current examples, i[1] is 105 list-like container of the f1 field of a batch of current examples, i[1] is
93 a list-like container of the f2 field, etc. 106 a list-like container of the f2 field, etc.
94 107
95 Using the second syntax, i1, i2, i3 will be list-like containers of the 108 Using the first syntax, all the fields will be returned in "i".
109 Beware that some datasets may not support this syntax, if the number
110 of fields is infinite (i.e. field values may be computed "on demand").
111
112 Using the third syntax, i1, i2, i3 will be list-like containers of the
96 f1, f2, and f3 fields of a batch of examples on each loop iteration. 113 f1, f2, and f3 fields of a batch of examples on each loop iteration.
97 114
98 PARAMETERS 115 PARAMETERS
99 - fieldnames (list of any type, default None): 116 - fieldnames (list of any type, default None):
100 The loop variables i1, i2, i3 (in the example above) should contain the 117 The loop variables i1, i2, i3 (in the example above) should contain the
113 Note: A list-like container is something like a tuple, list, numpy.ndarray or 130 Note: A list-like container is something like a tuple, list, numpy.ndarray or
114 any other object that supports integer indexing and slicing. 131 any other object that supports integer indexing and slicing.
115 132
116 """ 133 """
117 raise AbstractFunction() 134 raise AbstractFunction()
118 135
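A sketch of the three forms (the dataset and the field names 'x' and 'y' are hypothetical):

    # all fields, in the dataset's own order (may fail for infinite-width datasets)
    for batch in dataset.minibatches(None, minibatch_size=20):
        print batch[0]                      # batch of the first field's values
    # named fields, one indexable object per iteration
    for batch in dataset.minibatches(['x', 'y'], minibatch_size=20):
        print batch[0], batch[1]            # f1 batch, f2 batch
    # one list-like container per field, unpacked directly
    for x_batch, y_batch in dataset.minibatches(['x', 'y'], minibatch_size=20):
        print len(x_batch), len(y_batch)    # 20 and 20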
119 def fieldNames(self): 136 def hasFields(self, *fieldnames):
120 #Yoshua- 137 """
121 # This list may not be finite; what would make sense in the use you have 138 Return true if the given field name (or field names, if multiple arguments are
122 # in mind? 139 given) is recognized by the DataSet (i.e. can be used as a field name in one
123 # -JB 140 of the iterators).
124 #James- 141 """
125 # You are right. I had put this to be able to iterate over the fields 142 raise AbstractFunction()
126 # but maybe an iterator mechanism (over fields rather than examples) 143
127 # would be more appropriate. Fieldnames are needed in general
128 # by the iterators over examples or minibatches, to construct
129 # examples or minibatches with the corresponding names as attributes.
130 # -YB
131 """
132 Return an iterator (an object with an __iter__ method) that
133 iterates over the names of the fields. As a special case,
134 a list or a tuple of field names can be returned.
135 """"
136 # Note that some datasets
137 # may have virtual fields and support a virtually infinite number
138 # of possible field names. In that case, fieldNames() should
139 # either raise an error or iterate over a particular set of
140 # names as appropriate. Another option would be to iterate
141 # over the sub-datasets comprising a single field at a time.
142 # I am not sure yet what is most appropriate.
143 # -YB
144 """
145 raise AbstractFunction()
146
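The intended contract, sketched with hypothetical field names; hasFields answers for one or several names at once:

    if dataset.hasFields('x', 'y'):
        for x, y in dataset.zip('x', 'y'):
            pass    # both names are guaranteed to be recognized by the iterators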
147 def rename(self, *new_field_specifications): 144 def rename(self, *new_field_specifications):
148 #Yoshua- 145 #Yoshua-
149 # Do you mean for this to be a virtual method? 146 # Do you mean for this to be a virtual method?
150 # Wouldn't this functionality be easier to provide via a 147 # Wouldn't this functionality be easier to provide via a
151 # RenamingDataSet, such as the one I've written below? 148 # RenamingDataSet, such as the one I've written below?
163 of a matrix-like field). 160 of a matrix-like field).
164 """ 161 """
165 raise AbstractFunction() 162 raise AbstractFunction()
166 163
167 164
168 def apply_function(self, function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): 165 def applyFunction(self, function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True):
169 """ 166 """
170 Return a dataset that contains as fields the results of applying 167 Return a dataset that contains as fields the results of applying
171 the given function (example-wise) to the specified input_fields. The 168 the given function (example-wise) to the specified input_fields. The
172 function should return a sequence whose elements will be stored in 169 function should return a sequence whose elements will be stored in
173 fields whose names are given in the output_fields list. If copy_inputs 170 fields whose names are given in the output_fields list. If copy_inputs
200 n_batches = DataSet.minibatches_n_batches): 197 n_batches = DataSet.minibatches_n_batches):
201 dct = self.rename_dct 198 dct = self.rename_dct
202 new_fieldnames = [dct.get(f, f) for f in fieldnames] 199 new_fieldnames = [dct.get(f, f) for f in fieldnames]
203 return self.src.minibatches(new_fieldnames, minibatch_size, n_batches) 200 return self.src.minibatches(new_fieldnames, minibatch_size, n_batches)
204 201
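A usage sketch for RenamingDataSet; its constructor falls outside this hunk, so the (src, rename_dct) signature and the {new_name: source_name} convention are assumptions inferred from the code above:

    # hypothetical: expose the source's 'input' field under the new name 'x'
    renamed = RenamingDataSet(src_dataset, {'x': 'input'})
    for batch in renamed.minibatches(['x'], minibatch_size=10):
        print batch[0]    # values are fetched from src_dataset's 'input' field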
205 def fieldNames(self): 202 class FiniteLengthDataSet(DataSet):
206 return [self.rename_dct.get(f, f) for f in self.src.fieldNames()]
207 204 Virtual interface for datasets that have a finite length (number of examples),
208 205 and thus recognize a len(dataset) call.
209 class FiniteDataSet(DataSet): 206 """
210 """ 207 def __init__(self):
211 Virtual interface, a subclass of DataSet for datasets which have a finite, known length. 208 DataSet.__init__(self)
212 Examples are indexed by an integer between 0 and self.length()-1, 209
213 and a subdataset can be obtained by slicing. This may not be appropriate in general 210 def __len__(self):
214 but only for datasets which can be thought of as ones that access rows AND fields 211 """len(dataset) returns the number of examples in the dataset."""
215 in an efficient random access way. Users are encouraged to expect only the generic dataset 212 raise AbstractFunction()
216 interface in general. A FiniteDataSet is mainly useful when one has to obtain 213
217 a subset of examples (e.g. for splitting a dataset into training and test sets). 214
218 """ 215 class SliceableDataSet(DataSet):
219 216 """
220 class FiniteDataSetIterator(object): 217 Virtual interface, a subclass of DataSet for datasets which are sliceable
221 """ 218 and whose individual elements can be accessed, generally respecting the
222 If the fieldnames list is empty, it means that we want to see ALL the fields. 219 python semantics for [spec], where spec is either a non-negative integer
223 """ 220 (for selecting one example), or a python slice (for selecting a sub-dataset
224 def __init__(self,dataset,minibatch_size=1,fieldnames=[]): 221 comprising the specified examples). This is useful for obtaining
225 self.dataset=dataset 222 sub-datasets, e.g. for splitting a dataset into training and test sets.
226 self.minibatch_size=minibatch_size 223 """
227 assert minibatch_size>=1 and minibatch_size<=len(dataset) 224 def __init__(self):
228 self.current = -self.minibatch_size 225 DataSet.__init__(self)
229 self.fieldnames = fieldnames
230 if len(dataset) % minibatch_size:
231 raise NotImplementedError()
232
233 def __iter__(self):
234 return self
235 226
236 def next(self):
237 self.current+=self.minibatch_size
238 if self.current>=len(self.dataset):
239 self.current=-self.minibatch_size
240 raise StopIteration
241 if self.minibatch_size==1:
242 complete_example=self.dataset[self.current]
243 else:
244 complete_example=self.dataset[self.current:self.current+self.minibatch_size]
245 if self.fieldnames:
246 return Example(self.fieldnames,list(complete_example))
247 else:
248 return complete_example
249
250 def __init__(self):
251 pass
252
253 def minibatches(self, 227 def minibatches(self,
254 fieldnames = DataSet.minibatches_fieldnames, 228 fieldnames = DataSet.minibatches_fieldnames,
255 minibatch_size = DataSet.minibatches_minibatch_size, 229 minibatch_size = DataSet.minibatches_minibatch_size,
256 n_batches = DataSet.minibatches_n_batches): 230 n_batches = DataSet.minibatches_n_batches):
257 """ 231 """
258 If the fieldnames list is empty, it means that we want to see ALL the fields.
259
260 If n_batches is None, we want to see as many examples as possible 232 If n_batches is None, we want to see as many examples as possible
261 for the given minibatch_size. 233 for the given minibatch_size (possibly missing a few at the end of the dataset).
262 """ 234 """
263 # substitute the defaults: 235 # substitute the defaults:
264 if fieldnames is None: fieldnames = self.fieldNames()
265 if n_batches is None: n_batches = len(self) / minibatch_size 236 if n_batches is None: n_batches = len(self) / minibatch_size
266 return DataSet.Iterator(self, fieldnames, minibatch_size, n_batches) 237 return DataSet.Iterator(self, fieldnames, minibatch_size, n_batches)
267 238
268 def __getattr__(self,fieldname):
269 """Return an that can iterate over the values of the field in this dataset."""
270 return self(fieldname)
271
272 def __call__(self,*fieldnames):
273 """Return a sub-dataset containing only the given fieldnames as fields.
274
275 The return value's default iterator will iterate only over the given
276 fields.
277 """
278 raise AbstractFunction()
279
280 def __len__(self):
281 """len(dataset) returns the number of examples in the dataset."""
282 raise AbstractFunction()
283
284 def __getitem__(self,i): 239 def __getitem__(self,i):
285 """dataset[i] returns the (i+1)-th example of the dataset.""" 240 """dataset[i] returns the (i+1)-th example of the dataset."""
286 raise AbstractFunction() 241 raise AbstractFunction()
287 242
288 def __getslice__(self,*slice_args): 243 def __getslice__(self,*slice_args):
289 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" 244 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
290 raise AbstractFunction() 245 raise AbstractFunction()
246
247
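For example (a sketch assuming a dataset that is both sliceable and of known length), a train/test split reduces to slicing:

    n = len(dataset)                  # requires FiniteLengthDataSet as well
    n_train = int(0.8 * n)
    train_set = dataset[0:n_train]    # sub-dataset of consecutive examples
    test_set = dataset[n_train:n]
    first_example = dataset[0]        # a single example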
248 class FiniteWidthDataSet(DataSet):
249 """
250 Virtual interface for datasets that have a finite width (number of fields),
251 and thus return a list of fieldNames.
252 """
253 def __init__(self):
254 DataSet.__init__(self)
255
256 def hasFields(self, *fieldnames):
257 has_fields=True
258 for fieldname in fieldnames:
259 if fieldname not in self.fieldNames():
260 has_fields=False
261 return has_fields
262
263 def fieldNames(self):
264 """Return the list of field names that are supported by the iterators,
265 and for which hasFields(fieldname) would return True."""
266 raise AbstractFunction()
267
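A minimal concrete sketch (hypothetical class and field names): a subclass only needs to provide fieldNames(), and the hasFields() defined above works unchanged:

    class TwoFieldDataSet(FiniteWidthDataSet):
        def fieldNames(self):
            return ['x', 'y']

    d = TwoFieldDataSet()
    print d.hasFields('x')         # True
    print d.hasFields('x', 'z')    # False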
291 268
292 # we may want ArrayDataSet defined in another python file 269 # we may want ArrayDataSet defined in another python file
293 270
294 import numpy 271 import numpy
295 272
324 # many complicated things remain to be done: 301 # many complicated things remain to be done:
325 # - find common dtype 302 # - find common dtype
326 # - decide what to do with extra dimensions if not the same in all fields 303 # - decide what to do with extra dimensions if not the same in all fields
327 # - try to see if we can avoid the copy? 304 # - try to see if we can avoid the copy?
328 305
329 class ArrayDataSet(FiniteDataSet): 306 class ArrayDataSet(FiniteLengthDataSet,FiniteWidthDataSet,SliceableDataSet):
330 """ 307 """
331 An ArrayDataSet behaves like a numpy array but adds the notion of named fields 308 An ArrayDataSet behaves like a numpy array but adds the notion of named fields
332 from DataSet (and the ability to view the values of multiple fields as an 'Example'). 309 from DataSet (and the ability to view the values of multiple fields as an 'Example').
333 It is a fixed-length and fixed-width dataset 310 It is a fixed-length and fixed-width dataset
334 in which each element is a fixed dimension numpy array or a number, hence the whole 311 in which each element is a fixed dimension numpy array or a number, hence the whole
340 """ 317 """
341 318
342 class Iterator(LookupList): 319 class Iterator(LookupList):
343 """An iterator over a finite dataset that implements wrap-around""" 320 """An iterator over a finite dataset that implements wrap-around"""
344 def __init__(self, dataset, fieldnames, minibatch_size, next_max): 321 def __init__(self, dataset, fieldnames, minibatch_size, next_max):
345 LookupList.__init__(self, fieldnames, [0] * len(fieldnames)) 322 if fieldnames is None: fieldnames = dataset.fieldNames()
323 LookupList.__init__(self, fieldnames, [0]*len(fieldnames))
346 self.dataset=dataset 324 self.dataset=dataset
347 self.minibatch_size=minibatch_size 325 self.minibatch_size=minibatch_size
348 self.next_count = 0 326 self.next_count = 0
349 self.next_max = next_max 327 self.next_max = next_max
350 self.current = -self.minibatch_size 328 self.current = -self.minibatch_size
390 dataview = data[self.current:] 368 dataview = data[self.current:]
391 upper -= rows 369 upper -= rows
392 assert upper > 0 370 assert upper > 0
393 dataview = self.matcat(dataview, data[:upper]) 371 dataview = self.matcat(dataview, data[:upper])
394 372
395
396 self._values = [dataview[:, self.dataset.fields[f]]\ 373 self._values = [dataview[:, self.dataset.fields[f]]\
397 for f in self._names] 374 for f in self._names]
398
399 return self 375 return self
400 376
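The wrap-around behaviour of this iterator can be illustrated in isolation with plain numpy (matcat is assumed to concatenate two matrices row-wise):

    import numpy
    data = numpy.arange(10).reshape(5, 2)    # 5 examples of width 2
    current, minibatch_size = 4, 2           # upper = 6 exceeds len(data): wrap around
    upper = current + minibatch_size
    dataview = numpy.concatenate([data[current:], data[:upper - len(data)]])
    print dataview                           # row 4 followed by row 0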
401 377
402 def __init__(self, data, fields=None): 378 def __init__(self, data, fields=None):
403 """ 379 """
427 def minibatches(self, 403 def minibatches(self,
428 fieldnames = DataSet.minibatches_fieldnames, 404 fieldnames = DataSet.minibatches_fieldnames,
429 minibatch_size = DataSet.minibatches_minibatch_size, 405 minibatch_size = DataSet.minibatches_minibatch_size,
430 n_batches = DataSet.minibatches_n_batches): 406 n_batches = DataSet.minibatches_n_batches):
431 """ 407 """
432 If the fieldnames list is empty, it means that we want to see ALL the fields. 408 If the fieldnames list is None, it means that we want to see ALL the fields.
433 409
434 If n_batches is None, we want to see all the examples possible 410 If n_batches is None, we want to see all the examples possible
435 for the given minibatch_size. 411 for the given minibatch_size (possibly missing some near the end).
436 """ 412 """
437 # substitute the defaults: 413 # substitute the defaults:
438 if fieldnames is None: fieldnames = self.fieldNames()
439 if n_batches is None: n_batches = len(self) / minibatch_size 414 if n_batches is None: n_batches = len(self) / minibatch_size
440 return ArrayDataSet.Iterator(self, fieldnames, minibatch_size, n_batches) 415 return ArrayDataSet.Iterator(self, fieldnames, minibatch_size, n_batches)
441 416
442 def __getattr__(self,fieldname): 417 def __getattr__(self,fieldname):
443 """ 418 """
460 for fieldname,fieldslice in self.fields.items(): 435 for fieldname,fieldslice in self.fields.items():
461 new_fields[fieldname]=slice(fieldslice.start-min_col,fieldslice.stop-min_col,fieldslice.step) 436 new_fields[fieldname]=slice(fieldslice.start-min_col,fieldslice.stop-min_col,fieldslice.step)
462 return ArrayDataSet(self.data[:,min_col:max_col],fields=new_fields) 437 return ArrayDataSet(self.data[:,min_col:max_col],fields=new_fields)
463 438
464 def fieldNames(self): 439 def fieldNames(self):
465 """Return the list of field names that are supported by getattr and getFields.""" 440 """Return the list of field names that are supported by getattr and hasField."""
466 return self.fields.keys() 441 return self.fields.keys()
467 442
468 def __len__(self): 443 def __len__(self):
469 """len(dataset) returns the number of examples in the dataset.""" 444 """len(dataset) returns the number of examples in the dataset."""
470 return len(self.data) 445 return len(self.data)
500 """ 475 """
501 if not self.fields: 476 if not self.fields:
502 return self.data 477 return self.data
503 # else, select subsets of columns mapped by the fields 478 # else, select subsets of columns mapped by the fields
504 columns_used = numpy.zeros((self.data.shape[1]),dtype=bool) 479 columns_used = numpy.zeros((self.data.shape[1]),dtype=bool)
480 overlapping_fields = False
481 n_columns = 0
505 for field_slice in self.fields.values(): 482 for field_slice in self.fields.values():
506 for c in xrange(field_slice.start,field_slice.stop,field_slice.step): 483 for c in xrange(field_slice.start,field_slice.stop,field_slice.step):
484 n_columns += 1
485 if columns_used[c]: overlapping_fields=True
507 columns_used[c]=True 486 columns_used[c]=True
508 # try to figure out if we can map all the slices into one slice: 487 # try to figure out if we can map all the slices into one slice:
509 mappable_to_one_slice = True 488 mappable_to_one_slice = not overlapping_fields
510 start=0 489 if not overlapping_fields:
511 while start<len(columns_used) and not columns_used[start]: 490 start=0
512 start+=1 491 while start<len(columns_used) and not columns_used[start]:
513 stop=len(columns_used) 492 start+=1
514 while stop>0 and not columns_used[stop-1]: 493 stop=len(columns_used)
515 stop-=1 494 while stop>0 and not columns_used[stop-1]:
516 step=0 495 stop-=1
517 i=start 496 step=0
518 while i<stop: 497 i=start
519 j=i+1 498 while i<stop:
520 while j<stop and not columns_used[j]: 499 j=i+1
521 j+=1 500 while j<stop and not columns_used[j]:
522 if step: 501 j+=1
523 if step!=j-i: 502 if step:
524 mappable_to_one_slice = False 503 if step!=j-i:
525 break 504 mappable_to_one_slice = False
526 else: 505 break
527 step = j-i 506 else:
528 i=j 507 step = j-i
508 i=j
529 if mappable_to_one_slice: 509 if mappable_to_one_slice:
530 return self.data[:,slice(start,stop,step)] 510 return self.data[:,slice(start,stop,step)]
531 # else make contiguous copy 511 # else make contiguous copy (copying the overlapping columns)
532 n_columns = sum(columns_used) 512 result = numpy.zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype)
533 result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype)
534 print result.shape
535 c=0 513 c=0
536 for field_slice in self.fields.values(): 514 for field_slice in self.fields.values():
537 slice_width=field_slice.stop-field_slice.start/field_slice.step 515 slice_width=(field_slice.stop-field_slice.start)/field_slice.step
538 # copy the field here 516 # copy the field here
539 result[:,slice(c,slice_width)]=self.data[:,field_slice] 517 result[:,slice(c,c+slice_width)]=self.data[:,field_slice]
540 c+=slice_width 518 c+=slice_width
541 return result 519 return result
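To make the column mapping concrete (a sketch; it assumes the fields constructor argument is a {name: slice} dictionary, as the rest of the class suggests):

    import numpy
    data = numpy.arange(20).reshape(4, 5)
    # two fields covering columns 0-1 and 3-4; column 2 is unused, so the used
    # columns cannot be expressed as one slice and a contiguous copy is returned
    d = ArrayDataSet(data, fields={'x': slice(0, 2, 1), 'y': slice(3, 5, 1)})
    print d.asarray()    # 4x4 array holding columns 0,1,3,4 of data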
542 520
543 class ApplyFunctionDataset(DataSet): 521 class ApplyFunctionDataSet(DataSet):
544 """ 522 """
545 A dataset that contains as fields the results of applying 523 A dataset that contains as fields the results of applying
546 a given function (example-wise) to specified input_fields of a source 524 a given function (example-wise) to specified input_fields of a source
547 dataset. The function should return a sequence whose elements will be stored in 525 dataset. The function should return a sequence whose elements will be stored in
548 fields whose names are given in the output_fields list. If copy_inputs 526 fields whose names are given in the output_fields list. If copy_inputs
581 559
582 def minibatches(self, 560 def minibatches(self,
583 fieldnames = DataSet.minibatches_fieldnames, 561 fieldnames = DataSet.minibatches_fieldnames,
584 minibatch_size = DataSet.minibatches_minibatch_size, 562 minibatch_size = DataSet.minibatches_minibatch_size,
585 n_batches = DataSet.minibatches_n_batches): 563 n_batches = DataSet.minibatches_n_batches):
586 564
587 class Iterator(LookupList): 565 class Iterator(LookupList):
588 566
589 def __init__(self,dataset): 567 def __init__(self,dataset):
590 LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) 568 if fieldnames is None:
569 LookupList.__init__(self, [],[])
570 else:
571 LookupList.__init__(self, fieldnames, [0]*len(fieldnames))
591 self.dataset=dataset 572 self.dataset=dataset
592 if dataset.copy_inputs: 573 self.src_iterator=dataset.src.minibatches(None if fieldnames is None else list(set(fieldnames) | set(dataset.input_fields)),
593 src_fields=dataset.fieldNames() 574 minibatch_size,n_batches)
594 else:
595 src_fields=dataset.input_fields
596 self.src_iterator=dataset.src.minibatches(src_fields,minibatch_size,n_batches)
597 575
598 def __iter__(self): 576 def __iter__(self):
599 return self 577 return self
600 578
601 def next(self): 579 def next(self):
602 src_examples = self.src_iterator.next() 580 src_examples = self.src_iterator.next()
603 if self.dataset.copy_inputs: 581 # apply the function to the input_fields only, even when extra fields were fetched
604 function_inputs = src_examples 582 function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields]
605 else: 583 outputs = Example(self.dataset.output_fields,self.dataset.function(*function_inputs))
606 function_inputs = 584 if self.dataset.copy_inputs:
607 [src_examples[field_name] for field_name in self.dataset.input_fields] 585 # with copy_inputs, return the copied source fields along with the computed outputs
608 return self.dataset.function(*function_inputs) 586 return src_examples + outputs
587 else:
588 return outputs
589
609 590
610 for fieldname in fieldnames: 591 for fieldname in (fieldnames or []):
611 assert fieldname in self.input_fields 592 assert fieldname in self.output_fields or self.src.hasFields(fieldname)
612 return Iterator(self) 593 return Iterator(self)
613 594
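A usage sketch for ApplyFunctionDataSet; its __init__ falls outside this hunk, so the constructor arguments are assumptions based on the docstring and on applyFunction above:

    # hypothetical function mapping two input fields to one output field
    def add_fields(x, y):
        return (x + y,)

    applied = ApplyFunctionDataSet(src_dataset, add_fields,
                                   input_fields=['x', 'y'], output_fields=['z'])
    for batch in applied.minibatches(['z'], minibatch_size=10):
        print batch[0]    # batches of the computed 'z' field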
614 595