comparison: dataset.py @ 23:526e192b0699

Working on ApplyFunctionDataSet, added constraint that
DataSet iterators must have a next_index() method.

author | bengioy@esprit.iro.umontreal.ca
date | Wed, 09 Apr 2008 18:27:13 -0400
parents | b6b36f65664f
children | 672fe4b23032
22:b6b36f65664f | 23:526e192b0699
15 To iterate over examples, there are several possibilities: | 15 To iterate over examples, there are several possibilities: |
16 - for example in dataset.zip([field1, field2,field3, ...]) | 16 - for example in dataset.zip([field1, field2,field3, ...]) |
17 - for val1,val2,val3 in dataset.zip([field1, field2,field3]) | 17 - for val1,val2,val3 in dataset.zip([field1, field2,field3]) |
18 - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N) | 18 - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N) |
19 - for example in dataset | 19 - for example in dataset |
20 Each of these is documented below. | 20 Each of these is documented below. All of these iterators are expected |
21 to provide, in addition to the usual 'next()' method, a 'next_index()' method | |
22 which returns a non-negative integer pointing to the position of the next | |
23 example that will be returned by 'next()' (or of the first example in the | |
24 next minibatch returned). This is important because these iterators | |
25 can wrap around the dataset in order to do multiple passes through it, | |
26 in possibly irregular ways if the minibatch size is not a divisor of the | |
27 dataset length. | |
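As a usage sketch of this contract (the ArrayDataSet constructor arguments and field names below are assumptions for illustration, not part of this changeset), next_index() can be queried before each call to next():

    import numpy
    # Hypothetical 5-example dataset with two single-column fields 'x' and 'y'.
    data = numpy.array([[0., 0.], [1., 1.], [2., 4.], [3., 9.], [4., 16.]])
    dataset = ArrayDataSet(data, fields={'x': slice(0, 1), 'y': slice(1, 2)})

    it = dataset.minibatches(['x', 'y'], minibatch_size=2)
    for k in xrange(4):
        start = it.next_index()   # position of the first example of the coming minibatch
        minibatch = it.next()
        # since 2 does not divide 5, some minibatches wrap around the end of the
        # data, and 'start' tells the caller where each one begins.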
21 | 28 |
22 Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. | 29 Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. |
23 | 30 |
24 Note: The content of a field can be of any type. | 31 Note: The content of a field can be of any type. |
25 | 32 |
38 """ | 45 """ |
39 | 46 |
40 def __init__(self): | 47 def __init__(self): |
41 pass | 48 pass |
42 | 49 |
43 class Iter(LookupList): | 50 class Iterator(LookupList): |
44 def __init__(self, ll): | 51 def __init__(self, ll): |
45 LookupList.__init__(self, ll.keys(), ll.values()) | 52 LookupList.__init__(self, ll.keys(), ll.values()) |
46 self.ll = ll | 53 self.ll = ll |
47 def __iter__(self): #makes for loop work | 54 def __iter__(self): #makes for loop work |
48 return self | 55 return self |
49 def next(self): | 56 def next(self): |
50 self.ll.next() | 57 self.ll.next() |
51 self._values = [v[0] for v in self.ll._values] | 58 self._values = [v[0] for v in self.ll._values] |
52 return self | 59 return self |
60 def next_index(self): | |
61 return self.ll.next_index() | |
53 | 62 |
54 def __iter__(self): | 63 def __iter__(self): |
55 """Supports the syntax "for i in dataset: ..." | 64 """Supports the syntax "for i in dataset: ..." |
56 | 65 |
57 Using this syntax, "i" will be an Example instance (or equivalent) with | 66 Using this syntax, "i" will be an Example instance (or equivalent) with |
59 a field of a single example. Fields should be accessible via | 68 a field of a single example. Fields should be accessible via |
60 i["fielname"] or i[3] (in the order defined by the elements of the | 69 i["fielname"] or i[3] (in the order defined by the elements of the |
61 Example returned by this iterator), but the derived class is free | 70 Example returned by this iterator), but the derived class is free |
62 to accept any type of identifier, and add extra functionality to the iterator. | 71 to accept any type of identifier, and add extra functionality to the iterator. |
63 """ | 72 """ |
64 return DataSet.Iter(self.minibatches(None, minibatch_size = 1)) | 73 return DataSet.Iterator(self.minibatches(None, minibatch_size = 1)) |
65 | 74 |
66 def zip(self, *fieldnames): | 75 def zip(self, *fieldnames): |
67 """ | 76 """ |
68 Supports two forms of syntax: | 77 Supports two forms of syntax: |
69 | 78 |
79 f1, f2, and f3 fields of a single example on each loop iteration. | 88 f1, f2, and f3 fields of a single example on each loop iteration. |
80 | 89 |
81 The derived class may accept fieldname arguments of any type. | 90 The derived class may accept fieldname arguments of any type. |
82 | 91 |
83 """ | 92 """ |
84 return DataSet.Iter(self.minibatches(fieldnames, minibatch_size = 1)) | 93 return DataSet.Iterator(self.minibatches(fieldnames, minibatch_size = 1)) |
85 | 94 |
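As a usage sketch (the field names and the use() helper are hypothetical placeholders; the fields are passed as separate arguments per the *fieldnames signature, and the list form shown in the module docstring is assumed to behave the same way):

    # Each loop iteration yields an Example restricted to the requested fields:
    for example in dataset.zip('x', 'y'):
        use(example['x'], example['y'])

    # Or unpack the requested fields directly, in the order they were given:
    for x, y in dataset.zip('x', 'y'):
        use(x, y)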
86 minibatches_fieldnames = None | 95 minibatches_fieldnames = None |
87 minibatches_minibatch_size = 1 | 96 minibatches_minibatch_size = 1 |
88 minibatches_n_batches = None | 97 minibatches_n_batches = None |
89 def minibatches(self, | 98 def minibatches(self, |
139 given) is recognized by the DataSet (i.e. can be used as a field name in one | 148 given) is recognized by the DataSet (i.e. can be used as a field name in one |
140 of the iterators). | 149 of the iterators). |
141 """ | 150 """ |
142 raise AbstractFunction() | 151 raise AbstractFunction() |
143 | 152 |
144 def rename(*new_field_specifications): | 153 def merge_fields(*specifications): |
145 #Yoshua- | |
146 # Do you mean for this to be a virtual method? | |
147 # Wouldn't this functionality be easier to provide via a | |
148 # RenamingDataSet, such as the one I've written below? | |
149 # -JB | |
150 # You are right. Whichever implementation, however, we need a generic way to | |
151 # 'concatenate' fields, to handle the ([old_field1, old_field2, ...], new_field) semantics. | |
152 # -YB | |
153 """ | 154 """ |
154 Return a new dataset that maps old fields (of self) to new fields (of the returned | 155 Return a new dataset that maps old fields (of self) to new fields (of the returned |
155 dataset). The minimal syntax that should be supported is the following: | 156 dataset). The minimal syntax that should be supported is the following: |
156 new_field_specifications = [new_field_spec1, new_field_spec2, ...] | 157 new_field_specifications = [new_field_spec1, new_field_spec2, ...] |
157 new_field_spec = ([old_field1, old_field2, ...], new_field) | 158 new_field_spec = ([old_field1, old_field2, ...], new_field) |
159 support additional indexing schemes within each field (e.g. column slice | 160 support additional indexing schemes within each field (e.g. column slice |
160 of a matrix-like field). | 161 of a matrix-like field). |
161 """ | 162 """ |
162 raise AbstractFunction() | 163 raise AbstractFunction() |
163 | 164 |
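A minimal sketch of the intended call, with hypothetical field names:

    # Combine two source columns into a single new field 'input', and expose
    # the old 'label' field under the new name 'target'.
    merged = dataset.merge_fields((['x0', 'x1'], 'input'),
                                  (['label'], 'target'))
    # 'merged' is a new dataset whose visible fields are 'input' and 'target'.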
165 def merge_field_values(*field_value_pairs): | |
166 """ | |
167 Return the value that corresponds to merging the values of several fields, | |
168 given as arguments (field_name, field_value) pairs with self.hasField(field_name). | |
169 This may be used by implementations of merge_fields. | |
170 Raise a ValueError if the operation is not possible. | |
171 """ | |
172 fieldnames,fieldvalues = zip(*field_value_pairs) | |
173 raise ValueError("Unable to merge values of these fields:"+repr(fieldnames)) | |
174 | |
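As an illustration of how a concrete dataset might provide this merge (a sketch assuming field values are numpy arrays; not part of this changeset):

    import numpy

    def merge_field_values(self, *field_value_pairs):
        # Concatenate the per-field values side by side; only meaningful when
        # every value is an array (or scalar) of compatible shape.
        fieldnames, fieldvalues = zip(*field_value_pairs)
        try:
            return numpy.hstack(fieldvalues)
        except Exception:
            raise ValueError("Unable to merge values of these fields:" + repr(fieldnames))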
175 def examples2minibatch(examples): | |
176 """ | |
177 Combine a list of Examples into a minibatch. A minibatch is an Example whose fields | |
178 are iterable over the examples of the minibatch. | |
179 """ | |
180 raise AbstractFunction() | |
181 | |
182 def rename(rename_dict): | |
183 """ | |
184 Return a new dataset that renames fields, using a dictionary that maps old field | |
185 names to new field names. The only fields visible in the returned dataset are those | |
186 whose names are keys of the rename_dict. | |
187 """ | |
188 return RenamingDataSet(self,rename_dict) | |
164 | 189 |
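Usage sketch with hypothetical field names:

    # The returned dataset exposes only 'inputs' and 'targets', backed by the
    # original 'x' and 'y' fields of self.
    renamed = dataset.rename({'x': 'inputs', 'y': 'targets'})
    for inputs, targets in renamed.zip('inputs', 'targets'):
        pass  # work with the renamed fields here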
165 def applyFunction(function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): | 190 def applyFunction(function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): |
166 """ | 191 """ |
167 Return a dataset that contains as fields the results of applying | 192 Return a dataset that contains as fields the results of applying |
168 the given function (example-wise) to the specified input_fields. The | 193 the given function (example-wise) to the specified input_fields. The |
276 # dataset supports an as_array_dataset member function, and return that if | 301 # dataset supports an as_array_dataset member function, and return that if |
277 # possible. | 302 # possible. |
278 if hasattr(dataset, 'as_array_dataset'): | 303 if hasattr(dataset, 'as_array_dataset'): |
279 return dataset.as_array_dataset() | 304 return dataset.as_array_dataset() |
280 | 305 |
281 raise NotImplementedError() | 306 raise NotImplementedError |
282 | 307 |
283 # Make ONE big minibatch with all the examples, to separate the fields. | 308 # Make ONE big minibatch with all the examples, to separate the fields. |
284 n_examples = len(dataset) | 309 n_examples = len(dataset) |
285 batch = dataset.minibatches( minibatch_size = len(dataset)).next() | 310 batch = dataset.minibatches( minibatch_size = len(dataset)).next() |
286 | 311 |
341 assert a.dtype is b.dtype | 366 assert a.dtype is b.dtype |
342 rval = numpy.empty( (a0 + b0, a1), dtype=a.dtype) | 367 rval = numpy.empty( (a0 + b0, a1), dtype=a.dtype) |
343 rval[:a0,:] = a | 368 rval[:a0,:] = a |
344 rval[a0:,:] = b | 369 rval[a0:,:] = b |
345 return rval | 370 return rval |
371 | |
372 def next_index(self): | |
373 n_rows = self.dataset.data.shape[0] | |
374 next_i = self.current+self.minibatch_size | |
375 if next_i >= n_rows: | |
376 next_i -= n_rows | |
377 return next_i | |
346 | 378 |
347 def next(self): | 379 def next(self): |
348 | 380 |
349 #check for end-of-loop | 381 #check for end-of-loop |
350 self.next_count += 1 | 382 self.next_count += 1 |
351 if self.next_count == self.next_max: | 383 if self.next_count == self.next_max: |
352 raise StopIteration | 384 raise StopIteration |
353 | 385 |
354 #determine the first and last elements of the slice we'll return | 386 #determine the first and last elements of the slice we'll return |
355 rows = self.dataset.data.shape[0] | 387 n_rows = self.dataset.data.shape[0] |
356 self.current += self.minibatch_size | 388 self.current = self.next_index() |
357 if self.current >= rows: | |
358 self.current -= rows | |
359 upper = self.current + self.minibatch_size | 389 upper = self.current + self.minibatch_size |
360 | 390 |
361 data = self.dataset.data | 391 data = self.dataset.data |
362 | 392 |
363 if upper <= rows: | 393 if upper <= n_rows: |
364 #this is the easy case, we only need one slice | 394 #this is the easy case, we only need one slice |
365 dataview = data[self.current:upper] | 395 dataview = data[self.current:upper] |
366 else: | 396 else: |
367 # the minibatch wraps around the end of the dataset | 397 # the minibatch wraps around the end of the dataset |
368 dataview = data[self.current:] | 398 dataview = data[self.current:] |
369 upper -= rows | 399 upper -= n_rows |
370 assert upper > 0 | 400 assert upper > 0 |
371 dataview = self.matcat(dataview, data[:upper]) | 401 dataview = self.matcat(dataview, data[:upper]) |
372 | 402 |
373 self._values = [dataview[:, self.dataset.fields[f]]\ | 403 self._values = [dataview[:, self.dataset.fields[f]]\ |
374 for f in self._names] | 404 for f in self._names] |
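To make the wrap-around slicing concrete, here is a standalone illustration of the same index arithmetic (hypothetical sizes; numpy.vstack plays the role of matcat for 2-D data):

    import numpy

    data = numpy.arange(10).reshape(5, 2)    # 5 examples, 2 columns
    n_rows, minibatch_size = data.shape[0], 2

    current = 4                               # a minibatch starting on the last row
    upper = current + minibatch_size          # 6 > n_rows, so the minibatch wraps
    if upper <= n_rows:
        dataview = data[current:upper]
    else:
        upper -= n_rows                       # upper becomes 1
        dataview = numpy.vstack((data[current:], data[:upper]))
    # dataview now contains rows 4 and 0: the minibatch straddles the end of
    # the dataset, exactly as in next() above.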
516 # copy the field here | 546 # copy the field here |
517 result[:,slice(c,c+slice_width)]=self.data[:,field_slice] | 547 result[:,slice(c,c+slice_width)]=self.data[:,field_slice] |
518 c+=slice_width | 548 c+=slice_width |
519 return result | 549 return result |
520 | 550 |
551 def rename(*new_field_specifications): | |
552 """ | |
553 Return a new dataset that maps old fields (of self) to new fields (of the returned | |
554 dataset). The minimal syntax that should be supported is the following: | |
555 new_field_specifications = [new_field_spec1, new_field_spec2, ...] | |
556 new_field_spec = ([old_field1, old_field2, ...], new_field) | |
557 In general both old_field and new_field should be strings, but some datasets may also | |
558 support additional indexing schemes within each field (e.g. column slice | |
559 of a matrix-like field). | |
560 """ | |
561 # if all old fields of each spec are | |
562 raise NotImplementedError() | |
563 | |
521 class ApplyFunctionDataSet(DataSet): | 564 class ApplyFunctionDataSet(DataSet): |
522 """ | 565 """ |
523 A dataset that contains as fields the results of applying | 566 A dataset that contains as fields the results of applying |
524 a given function (example-wise) to specified input_fields of a source | 567 a given function (example-wise) to specified input_fields of a source |
525 dataset. The function should return a sequence whose elements will be stored in | 568 dataset. The function should return a sequence whose elements will be stored in |
530 iterator). In any case, the computations may be delayed until the examples | 573 iterator). In any case, the computations may be delayed until the examples |
531 of self are requested. If cache is True, then | 574 of self are requested. If cache is True, then |
532 once the output fields for some examples have been computed, they | 575 once the output fields for some examples have been computed, they |
533 are cached (to avoid recomputation if the same examples are again requested). | 576 are cached (to avoid recomputation if the same examples are again requested). |
534 """ | 577 """ |
535 def __init__(self,src,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): | 578 def __init__(self,src,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True, compute_now=False): |
536 DataSet.__init__(self) | 579 DataSet.__init__(self) |
537 self.src=src | 580 self.src=src |
538 self.function=function | 581 self.function=function |
582 assert src.hasFields(input_fields) | |
539 self.input_fields=input_fields | 583 self.input_fields=input_fields |
540 self.output_fields=output_fields | 584 self.output_fields=output_fields |
585 assert not (copy_inputs and compute_now and not hasattr(src,'fieldNames')) | |
541 self.copy_inputs=copy_inputs | 586 self.copy_inputs=copy_inputs |
542 self.accept_minibatches=accept_minibatches | 587 self.accept_minibatches=accept_minibatches |
543 src_fieldnames = src.fieldNames() | |
544 if copy_inputs: | |
545 for src_field in src_fieldnames: | |
546 assert src_field not in output_fields | |
547 self.fieldnames=src_fieldnames+output_fields | |
548 else: | |
549 self.fieldnames=output_fields | |
550 for input_field in input_fields: | |
551 assert input_field in src_fieldnames | |
552 self.cache=cache | 588 self.cache=cache |
553 if cache: | 589 self.compute_now=compute_now |
590 if compute_now: | |
591 assert hasattr(src,'__len__') and len(src)>=0 | |
592 fieldnames = output_fields | |
593 if copy_inputs: fieldnames = src.fieldNames() + output_fields | |
594 if accept_minibatches: | |
595 # make a single minibatch with all the inputs | |
596 inputs = src.minibatches(input_fields,len(src)).next() | |
597 # and apply the function to it, and transpose into a list of examples (field values, actually) | |
598 self.cached_examples = zip(*Example(output_fields,function(*inputs))) | |
599 else: | |
600 # compute a list with one tuple per example, with the function outputs | |
601 self.cached_examples = [ function(input) for input in src.zip(input_fields) ] | |
602 elif cache: | |
554 # maybe a fixed-size array kind of structure would be more efficient than a list | 603 # maybe a fixed-size array kind of structure would be more efficient than a list |
555 # in the case where src is FiniteDataSet. -YB | 604 # in the case where src is FiniteDataSet. -YB |
556 self.cached_examples = [] | 605 self.cached_examples = [] |
557 | 606 |
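The compute_now branch above relies on zip(*...) to transpose between a per-field layout (one sequence of values per field) and a per-example layout (one tuple of field values per example); a tiny standalone illustration:

    fields = [(1, 2, 3), (10, 20, 30)]   # two fields, three examples
    examples = zip(*fields)              # [(1, 10), (2, 20), (3, 30)]
    back = zip(*examples)                # [(1, 2, 3), (10, 20, 30)]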
558 def fieldNames(self): return self.fieldnames | |
559 | |
560 def minibatches(self, | 607 def minibatches(self, |
561 fieldnames = DataSet.minibatches_fieldnames, | 608 fieldnames = DataSet.minibatches_fieldnames, |
562 minibatch_size = DataSet.minibatches_minibatch_size, | 609 minibatch_size = DataSet.minibatches_minibatch_size, |
563 n_batches = DataSet.minibatches_n_batches): | 610 n_batches = DataSet.minibatches_n_batches): |
564 | 611 |
565 class Iterator(LookupList): | 612 class Iterator(LookupList): |
566 | 613 |
567 def __init__(self,dataset): | 614 def __init__(self,dataset): |
568 if fieldnames is None: | 615 if fieldnames is None: |
569 LookupList.__init__(self, [],[]) | 616 assert hasattr(dataset,"fieldNames") |
570 else: | 617 fieldnames = dataset.fieldNames() |
571 LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) | 618 self.example_index=0 |
619 LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) | |
572 self.dataset=dataset | 620 self.dataset=dataset |
573 self.src_iterator=self.dataset.src.minibatches(list(set.union(set(fieldnames),set(self.dataset.input_fields))), | 621 self.src_iterator=dataset.src.minibatches(list(set.union(set(fieldnames),set(dataset.input_fields))), |
574 minibatch_size,n_batches) | 622 minibatch_size,n_batches) |
623 self.fieldnames_not_in_input = [] | |
624 if dataset.copy_inputs: | |
625 self.fieldnames_not_in_input = filter(lambda x: not x in dataset.input_fields, fieldnames) | |
575 | 626 |
576 def __iter__(self): | 627 def __iter__(self): |
577 return self | 628 return self |
578 | 629 |
630 def next_index(self): | |
631 return self.src_iterator.next_index() | |
632 | |
579 def next(self): | 633 def next(self): |
634 example_index = self.src_iterator.next_index() | |
580 src_examples = self.src_iterator.next() | 635 src_examples = self.src_iterator.next() |
581 if self.dataset.copy_inputs: | 636 if self.dataset.copy_inputs: |
637 function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields] | |
638 else: | |
582 function_inputs = src_examples | 639 function_inputs = src_examples |
640 if self.dataset.cached_examples: | |
641 cache_len=len(self.dataset.cached_examples) | |
642 if example_index<cache_len+minibatch_size: | |
643 outputs_list = self.dataset.cached_examples[example_index:example_index+minibatch_size] | |
644 # convert the minibatch list of examples | |
645 # into a list of fields each of which iterate over the minibatch | |
646 outputs = zip(*outputs_list) | |
647 else: | |
648 outputs = self.dataset.function(*function_inputs) | |
649 if self.dataset.cache: | |
650 # convert the list of fields, each of which can iterate over the minibatch | |
651 # into a list of examples in the minibatch (each of which is a list of field values) | |
652 outputs_list = zip(*outputs) | |
653 # copy the outputs_list into the cache | |
654 for i in xrange(cache_len,example_index): | |
655 self.cached_examples.append(None) | |
656 self.dataset.cached_examples += outputs_list | |
583 else: | 657 else: |
584 function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields] | 658 outputs = self.dataset.function(*function_inputs) |
585 outputs = Example(self.dataset.output_fields,self.dataset.function(*function_inputs)) | 659 |
586 if self.dataset.copy_inputs: | 660 return Example(self.fieldnames_not_in_input+self.dataset.output_fields, |
587 return src_examples + outputs | 661 [src_examples[field_name] for field_name in self.fieldnames_not_in_input]+outputs) |
588 else: | 662 |
589 return outputs | |
590 | 663 |
591 for fieldname in fieldnames: | 664 for fieldname in fieldnames: |
592 assert fieldname in self.output_fields or self.src.hasFields(fieldname) | 665 assert fieldname in self.output_fields or self.src.hasFields(fieldname) |
593 return Iterator(self) | 666 return Iterator(self) |
594 | 667 |
595 | 668 |
669 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | |
670 """ | |
671 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the | |
672 user to define a set of fields as the 'input' field and a set of fields | |
673 as the 'target' field. Optionally, a single weight_field can also be defined. | |
674 """ | |
675 args = ((input_fields,'input'),(target_fields,'target')) | |
676 if weight_field: args+=(([weight_field],'weight'),) | |
677 return src_dataset.rename(*args) | |
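Usage sketch with a hypothetical source dataset and field names (note that rename() is expected here to accept ([old_fields...], new_name) specifications, i.e. the merge_fields-style argument format):

    # Present raw columns as the conventional 'input' and 'target' fields.
    train = supervised_learning_dataset(raw_dataset,
                                        input_fields=['x0', 'x1'],
                                        target_fields=['label'])
    for minibatch in train.minibatches(['input', 'target'], minibatch_size=32):
        pass  # minibatch['input'] and minibatch['target'] hold the grouped fields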
678 | |
679 | |
680 | |
681 |