comparison dataset.py @ 23:526e192b0699

Working on ApplyFunctionDataSet, added constraint that DataSet iterators must have a next_index() method.
author bengioy@esprit.iro.umontreal.ca
date Wed, 09 Apr 2008 18:27:13 -0400
parents b6b36f65664f
children 672fe4b23032
15     To iterate over examples, there are several possibilities:
16     - for example in dataset.zip([field1, field2, field3, ...])
17     - for val1,val2,val3 in dataset.zip([field1, field2, field3])
18     - for minibatch in dataset.minibatches([field1, field2, ...], minibatch_size=N)
19     - for example in dataset
20     Each of these is documented below. All of these iterators are expected
21     to provide, in addition to the usual 'next()' method, a 'next_index()' method
22     which returns a non-negative integer pointing to the position of the next
23     example that will be returned by 'next()' (or of the first example in the
24     next minibatch returned). This is important because these iterators
25     can wrap around the dataset in order to do multiple passes through it,
26     in possibly irregular ways if the minibatch size is not a divisor of the
27     dataset length. (A short usage sketch follows the notes below.)
28
29     Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content.
30
31     Note: The content of a field can be of any type.
32
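A minimal usage sketch of these iteration protocols, assuming a hypothetical dataset d that implements the interface above with fields 'x' and 'target' (d, its field names, and do_something() are illustrative assumptions, not defined in this file):

    # iterate example by example over all fields
    for example in d:
        do_something(example['x'], example['target'])

    # iterate over a subset of fields, unpacking their values
    for x, target in d.zip(['x', 'target']):
        do_something(x, target)

    # iterate over minibatches of 20 examples, tracking position with next_index()
    it = d.minibatches(['x', 'target'], minibatch_size=20)
    start = it.next_index()   # index of the first example of the next minibatch
    batch = it.next()         # batch['x'] iterates over the 'x' values of 20 examples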
38 """ 45 """
39 46
40 def __init__(self): 47 def __init__(self):
41 pass 48 pass
42 49
43 class Iter(LookupList): 50 class Iterator(LookupList):
44 def __init__(self, ll): 51 def __init__(self, ll):
45 LookupList.__init__(self, ll.keys(), ll.values()) 52 LookupList.__init__(self, ll.keys(), ll.values())
46 self.ll = ll 53 self.ll = ll
47 def __iter__(self): #makes for loop work 54 def __iter__(self): #makes for loop work
48 return self 55 return self
49 def next(self): 56 def next(self):
50 self.ll.next() 57 self.ll.next()
51 self._values = [v[0] for v in self.ll._values] 58 self._values = [v[0] for v in self.ll._values]
52 return self 59 return self
60 def next_index(self):
61 return self.ll.next_index()
53 62
63     def __iter__(self):
64         """Supports the syntax "for i in dataset: ..."
65
66         Using this syntax, "i" will be an Example instance (or equivalent) with
67         all the fields of DataSet self. Every field of "i" will give access to
68         a field of a single example. Fields should be accessible via
69         i["fieldname"] or i[3] (in the order defined by the elements of the
70         Example returned by this iterator), but the derived class is free
71         to accept any type of identifier, and add extra functionality to the iterator.
72         """
73         return DataSet.Iterator(self.minibatches(None, minibatch_size = 1))
74
75     def zip(self, *fieldnames):
76         """
77         Supports two forms of syntax:
78
88         f1, f2, and f3 fields of a single example on each loop iteration.
89
90         The derived class may accept fieldname arguments of any type.
91
92         """
93         return DataSet.Iterator(self.minibatches(fieldnames, minibatch_size = 1))
94
95     minibatches_fieldnames = None
96     minibatches_minibatch_size = 1
97     minibatches_n_batches = None
98     def minibatches(self,
148         given) is recognized by the DataSet (i.e. can be used as a field name in one
149         of the iterators).
150         """
151         raise AbstractFunction()
152
153     def merge_fields(self, *specifications):
154         """
155         Return a new dataset that maps old fields (of self) to new fields (of the returned
156         dataset). The minimal syntax that should be supported is the following:
157            new_field_specifications = [new_field_spec1, new_field_spec2, ...]
158            new_field_spec = ([old_field1, old_field2, ...], new_field)
159         In general both old_field and new_field should be strings, but some datasets may also
160         support additional indexing schemes within each field (e.g. column slice
161         of a matrix-like field).
162         """
163         raise AbstractFunction()
164
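A hedged illustration of the specification syntax above, assuming a hypothetical dataset d with fields 'x1', 'x2' and 'y' (none of these names come from this file):

    # merge 'x1' and 'x2' into a single 'input' field; expose 'y' as 'target'
    d2 = d.merge_fields((['x1', 'x2'], 'input'),
                        (['y'], 'target'))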
165     def merge_field_values(self, *field_value_pairs):
166         """
167         Return the value that corresponds to merging the values of several fields,
168         given as arguments (field_name, field_value) pairs with self.hasField(field_name).
169         This may be used by implementations of merge_fields.
170         Raise a ValueError if the operation is not possible.
171         """
172         fieldnames,fieldvalues = zip(*field_value_pairs)
173         raise ValueError("Unable to merge values of these fields:"+repr(fieldnames))
174
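A minimal sketch of how a subclass might override this hook, assuming its field values are numpy arrays (the class name and the hstack-based merge are illustrative assumptions):

    import numpy

    class MyArrayDataSet(DataSet):  # hypothetical subclass
        def merge_field_values(self, *field_value_pairs):
            fieldnames, fieldvalues = zip(*field_value_pairs)
            try:
                # concatenate array-valued fields along their last axis
                return numpy.hstack(fieldvalues)
            except ValueError:
                raise ValueError("Unable to merge values of these fields:" + repr(fieldnames))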
175     def examples2minibatch(self, examples):
176         """
177         Combine a list of Examples into a minibatch. A minibatch is an Example whose fields
178         are iterable over the examples of the minibatch.
179         """
180         raise AbstractFunction()
181
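The intended semantics can be illustrated with plain tuples standing in for Examples; the zip(*...) transpose is the key step (everything below is illustrative only):

    examples = [(0.0, 'a'), (1.0, 'b'), (2.0, 'c')]   # three examples, two fields each
    fields = list(zip(*examples))                     # transpose: one sequence per field
    assert fields == [(0.0, 1.0, 2.0), ('a', 'b', 'c')]
    # a concrete implementation would wrap these per-field sequences in an Example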
182     def rename(self, rename_dict):
183         """
184         Return a new dataset that renames fields, using a dictionary that maps old field
185         names to new field names. The only fields visible by the returned dataset are those
186         whose names are keys of the rename_dict.
187         """
188         return RenamingDataSet(self,rename_dict)
189
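A hedged usage sketch, assuming a hypothetical dataset d with fields 'x' and 'y':

    d2 = d.rename({'x': 'input', 'y': 'target'})
    # d2 exposes only 'input' and 'target', backed by d's 'x' and 'y'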
190     def applyFunction(self, function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True):
191         """
192         Return a dataset that contains as fields the results of applying
193         the given function (example-wise) to the specified input_fields. The
301     # dataset supports an as_array_dataset member function, and return that if
302     # possible.
303     if hasattr(dataset, 'as_array_dataset'):
304         return dataset.as_array_dataset()
305
306     raise NotImplementedError
307
308     # Make ONE big minibatch with all the examples, to separate the fields.
309     n_examples = len(dataset)
310     batch = dataset.minibatches(minibatch_size = len(dataset)).next()
311
366             assert a.dtype is b.dtype
367             rval = numpy.empty( (a0 + b0, a1), dtype=a.dtype)
368             rval[:a0,:] = a
369             rval[a0:,:] = b
370             return rval
371
372         def next_index(self):
373             n_rows = self.dataset.data.shape[0]
374             next_i = self.current+self.minibatch_size
375             if next_i >= n_rows:
376                 next_i -= n_rows
377             return next_i
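For instance, with 10 rows and a minibatch size of 4, the start index cycles 0, 4, 8, 2, 6, 0, ... because it wraps modulo the number of rows. A standalone illustration of that arithmetic (not part of the class above):

    n_rows, minibatch_size = 10, 4
    current, starts = 0, []
    for _ in range(6):
        starts.append(current)
        current = (current + minibatch_size) % n_rows
    assert starts == [0, 4, 8, 2, 6, 0]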
378
379         def next(self):
380
381             # check for end-of-loop
382             self.next_count += 1
383             if self.next_count == self.next_max:
384                 raise StopIteration
385
386             # determine the first and last elements of the slice we'll return
387             n_rows = self.dataset.data.shape[0]
388             self.current = self.next_index()
389             upper = self.current + self.minibatch_size
390
391             data = self.dataset.data
392
393             if upper <= n_rows:
394                 # this is the easy case, we only need one slice
395                 dataview = data[self.current:upper]
396             else:
397                 # the minibatch wraps around the end of the dataset
398                 dataview = data[self.current:]
399                 upper -= n_rows
400                 assert upper > 0
401                 dataview = self.matcat(dataview, data[:upper])
402
403             self._values = [dataview[:, self.dataset.fields[f]]
404                             for f in self._names]
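A standalone sketch of the wrap-around slicing; here matcat amounts to a row-wise concatenation, shown with numpy.vstack, and the numbers are illustrative assumptions:

    import numpy

    data = numpy.arange(20).reshape(10, 2)   # pretend: 10 examples, 2 columns
    current, minibatch_size = 8, 4
    upper = current + minibatch_size          # 12 > 10, so the minibatch wraps
    dataview = numpy.vstack((data[current:], data[:upper - data.shape[0]]))
    assert dataview.shape == (4, 2)           # rows 8, 9, 0, 1 of the original data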
546             # copy the field here
547             result[:,slice(c,c+slice_width)]=self.data[:,field_slice]
548             c+=slice_width
549         return result
550
551     def rename(self, *new_field_specifications):
552         """
553         Return a new dataset that maps old fields (of self) to new fields (of the returned
554         dataset). The minimal syntax that should be supported is the following:
555            new_field_specifications = [new_field_spec1, new_field_spec2, ...]
556            new_field_spec = ([old_field1, old_field2, ...], new_field)
557         In general both old_field and new_field should be strings, but some datasets may also
558         support additional indexing schemes within each field (e.g. column slice
559         of a matrix-like field).
560         """
561         # if all old fields of each spec are
562         raise NotImplementedError()
563
564 class ApplyFunctionDataSet(DataSet):
565     """
566     A dataset that contains as fields the results of applying
567     a given function (example-wise) to specified input_fields of a source
568     dataset. The function should return a sequence whose elements will be stored in
573     iterator). In any case, the computations may be delayed until the examples
574     of self are requested. If cache is True, then
575     once the output fields for some examples have been computed, they
576     are cached (to avoid recomputation if the same examples are again requested).
577     """
578     def __init__(self, src, function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True, compute_now=False):
579         DataSet.__init__(self)
580         self.src=src
581         self.function=function
582         assert src.hasFields(input_fields)
583         self.input_fields=input_fields
584         self.output_fields=output_fields
585         assert not (copy_inputs and compute_now and not hasattr(src,'fieldNames'))
586         self.copy_inputs=copy_inputs
587         self.accept_minibatches=accept_minibatches
588         self.cache=cache
589         self.compute_now=compute_now
590         if compute_now:
591             assert hasattr(src,'__len__') and len(src)>=0
592             fieldnames = output_fields
593             if copy_inputs: fieldnames = src.fieldNames() + output_fields
594             if accept_minibatches:
595                 # make a single minibatch with all the inputs
596                 inputs = src.minibatches(input_fields,len(src)).next()
597                 # and apply the function to it, and transpose into a list of examples (field values, actually)
598                 self.cached_examples = zip(*Example(output_fields,function(*inputs)))
599             else:
600                 # compute a list with one tuple per example, with the function outputs
601                 self.cached_examples = [ function(input) for input in src.zip(input_fields) ]
602         elif cache:
603             # maybe a fixed-size array kind of structure would be more efficient than a list
604             # in the case where src is FiniteDataSet. -YB
605             self.cached_examples = []
606
607     def minibatches(self,
608                     fieldnames = DataSet.minibatches_fieldnames,
609                     minibatch_size = DataSet.minibatches_minibatch_size,
610                     n_batches = DataSet.minibatches_n_batches):
611
612         class Iterator(LookupList):
613
614             def __init__(self, dataset, fieldnames=fieldnames):
615                 if fieldnames is None:
616                     assert hasattr(dataset,"fieldNames")
617                     fieldnames = dataset.fieldNames()
618                 self.example_index=0
619                 LookupList.__init__(self, fieldnames, [0]*len(fieldnames))
620                 self.dataset=dataset
621                 self.src_iterator=dataset.src.minibatches(list(set.union(set(fieldnames),set(dataset.input_fields))),
622                                                           minibatch_size,n_batches)
623                 self.fieldnames_not_in_input = []
624                 if dataset.copy_inputs:
625                     self.fieldnames_not_in_input = filter(lambda x: x not in dataset.input_fields, fieldnames)
626
627             def __iter__(self):
628                 return self
629
630             def next_index(self):
631                 return self.src_iterator.next_index()
632
633             def next(self):
634                 example_index = self.src_iterator.next_index()
635                 src_examples = self.src_iterator.next()
636                 if self.dataset.copy_inputs:
637                     function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields]
638                 else:
639                     function_inputs = src_examples
640                 if hasattr(self.dataset, 'cached_examples'):
641                     cache_len = len(self.dataset.cached_examples)
642                     if example_index + minibatch_size <= cache_len:
643                         outputs_list = self.dataset.cached_examples[example_index:example_index+minibatch_size]
644                         # convert the minibatch list of examples
645                         # into a list of fields, each of which iterates over the minibatch
646                         outputs = zip(*outputs_list)
647                     else:
648                         outputs = self.dataset.function(*function_inputs)
649                         if self.dataset.cache:
650                             # convert the list of fields, each of which can iterate over the minibatch,
651                             # into a list of examples in the minibatch (each of which is a list of field values)
652                             outputs_list = zip(*outputs)
653                             # copy the outputs_list into the cache
654                             for i in xrange(cache_len,example_index):
655                                 self.dataset.cached_examples.append(None)
656                             self.dataset.cached_examples += outputs_list
657                 else:
658                     outputs = self.dataset.function(*function_inputs)
659
660                 return Example(self.fieldnames_not_in_input+self.dataset.output_fields,
661                                [src_examples[field_name] for field_name in self.fieldnames_not_in_input]+list(outputs))
662
663
664         for fieldname in (fieldnames or []):
665             assert fieldname in self.output_fields or self.src.hasFields(fieldname)
666         return Iterator(self)
667
668
669 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
670     """
671     Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the
672     user to define a set of fields as the 'input' field and a set of fields
673     as the 'target' field. Optionally, a single weight_field can also be defined.
674     """
675     args = ((input_fields,'input'),(target_fields,'target'))
676     if weight_field: args+=(([weight_field],'weight'),)
677     return src_dataset.merge_fields(*args)
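A hedged usage sketch, assuming a hypothetical dataset d with fields 'pixels', 'label' and 'importance':

    train = supervised_learning_dataset(d,
                                        input_fields=['pixels'],
                                        target_fields=['label'],
                                        weight_field='importance')
    # train exposes the merged fields 'input', 'target' and 'weight'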
678
679
680
681