pylearn: comparison of dataset.py @ 290:9b533cc7874a
trying to get default implementations to work
author   James Bergstra <bergstrj@iro.umontreal.ca>
date     Thu, 05 Jun 2008 18:38:42 -0400
parents  ed70580f2324
children 174374d59405
276:271a16d42072 | 290:9b533cc7874a |
---|---|
1 | 1 |
2 from lookup_list import LookupList | 2 from lookup_list import LookupList as Example |
3 Example = LookupList | |
4 from misc import unique_elements_list_intersection | 3 from misc import unique_elements_list_intersection |
5 from string import join | 4 from string import join |
6 from sys import maxint | 5 from sys import maxint |
7 import numpy, copy | 6 import numpy, copy |
8 | 7 |
35 attribute_names = self.attributeNames() | 34 attribute_names = self.attributeNames() |
36 if return_copy: | 35 if return_copy: |
37 return [copy.copy(self.__getattribute__(name)) for name in attribute_names] | 36 return [copy.copy(self.__getattribute__(name)) for name in attribute_names] |
38 else: | 37 else: |
39 return [self.__getattribute__(name) for name in attribute_names] | 38 return [self.__getattribute__(name) for name in attribute_names] |
40 | |
41 | 39 |
42 class DataSet(AttributesHolder): | 40 class DataSet(AttributesHolder): |
43 """A virtual base class for datasets. | 41 """A virtual base class for datasets. |
44 | 42 |
45 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction | 43 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction |
232 self.fieldnames=fieldnames | 230 self.fieldnames=fieldnames |
233 self.minibatch_size=minibatch_size | 231 self.minibatch_size=minibatch_size |
234 self.n_batches=n_batches | 232 self.n_batches=n_batches |
235 self.n_batches_done=0 | 233 self.n_batches_done=0 |
236 self.next_row=offset | 234 self.next_row=offset |
237 self.offset=offset | |
238 self.L=len(dataset) | 235 self.L=len(dataset) |
239 assert offset+minibatch_size<=self.L | 236 self.offset=offset % self.L |
240 ds_nbatches = (self.L-self.next_row)/self.minibatch_size | 237 ds_nbatches = (self.L-self.next_row)/self.minibatch_size |
241 if n_batches is not None: | 238 if n_batches is not None: |
242 ds_nbatches = min(n_batches,ds_nbatches) | 239 ds_nbatches = min(n_batches,ds_nbatches) |
243 if fieldnames: | 240 if fieldnames: |
244 assert dataset.hasFields(*fieldnames) | 241 assert dataset.hasFields(*fieldnames) |
245 else: | 242 else: |
246 self.fieldnames=dataset.fieldNames() | 243 self.fieldnames=dataset.fieldNames() |
247 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, | 244 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, ds_nbatches,self.next_row) |
248 ds_nbatches,self.next_row) | |
249 | 245 |
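The key change in this hunk: instead of asserting that offset+minibatch_size fits inside the dataset, the constructor now wraps the offset modulo the dataset length. A minimal sketch of the arithmetic, with illustrative numbers only:

```python
# Illustrative numbers only; none of this is in dataset.py.
L = 100                 # len(dataset)
offset = 105            # caller-supplied offset past the end of the data
offset = offset % L     # the new constructor stores 5 instead of failing

minibatch_size = 8
# whole minibatches available before the end (Python 2 integer division)
ds_nbatches = (L - offset) / minibatch_size    # 95 / 8 == 11
```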
250 def __iter__(self): | 246 def __iter__(self): |
251 return self | 247 return self |
252 | 248 |
253 def next_index(self): | 249 def next_index(self): |
316 Using the first syntax, all the fields will be returned in "i". | 312 Using the first syntax, all the fields will be returned in "i". |
317 Using the third syntax, i1, i2, i3 will be list-like containers of the | 313 Using the third syntax, i1, i2, i3 will be list-like containers of the |
318 f1, f2, and f3 fields of a batch of examples on each loop iteration. | 314 f1, f2, and f3 fields of a batch of examples on each loop iteration. |
319 | 315 |
320 The minibatches iterator is expected to return upon each call to next() | 316 The minibatches iterator is expected to return upon each call to next() |
321 a DataSetFields object, which is a LookupList (indexed by the field names) whose | 317 a DataSetFields object, which is an Example (indexed by the field names) whose |
322 elements are iterable and indexable over the minibatch examples, and which keeps a pointer to | 318 elements are iterable and indexable over the minibatch examples, and which keeps a pointer to |
323 a sub-dataset that can be used to iterate over the individual examples | 319 a sub-dataset that can be used to iterate over the individual examples |
324 in the minibatch. Hence a minibatch can be converted back to a regular | 320 in the minibatch. Hence a minibatch can be converted back to a regular |
325 dataset or its fields can be looked at individually (and possibly iterated over). | 321 dataset or its fields can be looked at individually (and possibly iterated over). |
326 | 322 |
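A usage sketch of the iteration styles this docstring describes (the dataset and field names are hypothetical):

```python
# dataset, 'f1', 'f2', 'f3' are hypothetical.
for i in dataset.minibatches(minibatch_size=10):
    # i is a DataSetFields (an Example keyed by field name)
    print i['f1']

for f1, f2, f3 in dataset.minibatches(['f1', 'f2', 'f3'], minibatch_size=10):
    # f1, f2, f3 are list-like containers of the batch's field values
    print f1, f2, f3
```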
422 assert i in self.__dict__ # else it means we are trying to access a non-existing property | 418 assert i in self.__dict__ # else it means we are trying to access a non-existing property |
423 return self.__dict__[i] | 419 return self.__dict__[i] |
424 | 420 |
425 def __getitem__(self,i): | 421 def __getitem__(self,i): |
426 """ | 422 """ |
427 dataset[i] returns the (i+1)-th example of the dataset. | 423 @rtype: Example |
428 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. | 424 @returns: single or multiple examples |
429 dataset[i:j:s] returns the subdataset with examples i, i+s, i+2s, ... (stopping before j). | 425 |
430 dataset[[i1,i2,..,in]] returns the subdataset with examples i1,i2,...,in. | 426 @type i: integer or slice or <iterable> of integers |
431 dataset['key'] returns a property associated with the given 'key' string. | 427 @param i: |
432 If 'key' is a fieldname, then the VStacked field values (iterable over | 428 dataset[i] returns the (i+1)-th example of the dataset. |
433 field values) for that field is returned. Other keys may be supported | 429 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. |
434 by different dataset subclasses. The following key names are encouraged: | 430 dataset[i:j:s] returns the subdataset with examples i, i+s, i+2s, ... (stopping before j). |
435 - 'description': a textual description or name for the dataset | 431 dataset[[i1,i2,..,in]] returns the subdataset with examples i1,i2,...,in. |
436 - '<fieldname>.type': a type name or value for a given <fieldname> | 432 |
437 | 433 @note: |
438 Note that some stream datasets may be unable to implement random access, i.e. | 434 Some stream datasets may be unable to implement random access, i.e. |
439 arbitrary slicing/indexing | 435 arbitrary slicing/indexing because they can only iterate through |
440 because they can only iterate through examples one at a time (or one minibatch at a time) | 436 examples one at a time (or one minibatch at a time) and do not actually store or keep |
441 and do not actually store or keep past (or future) examples. | 437 past (or future) examples. |
442 | 438 |
443 The default implementation of getitem uses the minibatches iterator | 439 The default implementation of getitem uses the minibatches iterator |
444 to obtain one example, one slice, or a list of examples. It may not | 440 to obtain one example, one slice, or a list of examples. It may not |
445 always be the most efficient way to obtain the result, especially if | 441 always be the most efficient way to obtain the result, especially if |
446 the data are actually stored in a memory array. | 442 the data are actually stored in a memory array. |
447 """ | 443 """ |
448 # check for an index | 444 |
449 if type(i) is int: | 445 if type(i) is int: |
450 return DataSet.MinibatchToSingleExampleIterator( | 446 #TODO: consider asserting that i >= 0 |
451 self.minibatches(minibatch_size=1,n_batches=1,offset=i)).next() | 447 i_batch = self.minibatches_nowrap(self.fieldNames(), |
452 rows=None | 448 minibatch_size=1, n_batches=1, offset=i % len(self)) |
453 # or a slice | 449 return DataSet.MinibatchToSingleExampleIterator(i_batch).next() |
450 | |
451 #if i is a contiguous slice | |
452 if type(i) is slice and (i.step in (None, 1)): | |
453 offset = 0 if i.start is None else i.start | |
454 upper_bound = len(self) if i.stop is None else i.stop | |
455 return MinibatchDataSet(self.minibatches_nowrap(self.fieldNames(), | |
456 minibatch_size=upper_bound - offset, | |
457 n_batches=1, | |
458 offset=offset).next()) | |
459 | |
460 # if slice has a step param, convert it to list and handle it with the | |
461 # list code | |
454 if type(i) is slice: | 462 if type(i) is slice: |
455 #print 'i=',i | 463 offset = 0 if i.start is None else i.start |
456 if not i.start: i=slice(0,i.stop,i.step) | 464 upper_bound = len(self) if i.stop is None else i.stop |
457 if not i.stop: i=slice(i.start,len(self),i.step) | 465 i = list(range(offset, upper_bound, i.step)) |
458 if not i.step: i=slice(i.start,i.stop,1) | 466 |
459 if i.step is 1: | 467 # handle tuples, arrays, lists |
460 return self.minibatches(minibatch_size=i.stop-i.start,n_batches=1,offset=i.start).next().examples() | 468 if hasattr(i, '__getitem__'): |
461 rows = range(i.start,i.stop,i.step) | 469 for idx in i: |
462 # or a list of indices | 470 #disallow nested slices |
463 elif type(i) is list: | 471 if not isinstance(idx, int): |
464 rows = i | 472 raise TypeError(idx) |
465 if rows is not None: | 473 # call back into self.__getitem__ |
466 examples = [self[row] for row in rows] | 474 examples = [self.minibatches_nowrap(self.fieldNames(), |
467 fields_values = zip(*examples) | 475 minibatch_size=1, n_batches=1, offset=ii%len(self)).next() |
468 return MinibatchDataSet( | 476 for ii in i] |
469 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) | 477 # re-index the fields in each example by field instead of by example |
470 for fieldname,field_values | 478 field_values = [[] for blah in self.fieldNames()] |
471 in zip(self.fieldNames(),fields_values)]), | 479 for e in examples: |
472 self.valuesVStack,self.valuesHStack) | 480 for f,v in zip(field_values, e): |
481 f.append(v) | |
482 #build them into a LookupList (a.k.a. Example) | |
483 zz = zip(self.fieldNames(),field_values) | |
484 vst = [self.valuesVStack(fieldname,field_values) for fieldname,field_values in zz] | |
485 example = Example(self.fieldNames(), vst) | |
486 return MinibatchDataSet(example, self.valuesVStack, self.valuesHStack) | |
473 raise TypeError(i, type(i)) | 487 raise TypeError(i, type(i)) |
474 | 488 |
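A short sketch of the indexing semantics documented and implemented above (the dataset and indices are hypothetical):

```python
# dataset and the indices below are hypothetical.
ex = dataset[0]             # an Example: the first example
sub = dataset[10:20]        # contiguous slice: a MinibatchDataSet of 10 examples
strided = dataset[0:10:2]   # stepped slice: converted to a list of indices
picked = dataset[[1, 5, 9]] # arbitrary rows, restacked field by field
```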
475 def valuesHStack(self,fieldnames,fieldvalues): | 489 def valuesHStack(self,fieldnames,fieldvalues): |
476 """ | 490 """ |
477 Return a value that corresponds to concatenating (horizontally) several field values. | 491 Return a value that corresponds to concatenating (horizontally) several field values. |
491 if all_numpy: | 505 if all_numpy: |
492 return numpy.hstack(fieldvalues) | 506 return numpy.hstack(fieldvalues) |
493 # the default implementation of horizontal stacking is to put values in a list | 507 # the default implementation of horizontal stacking is to put values in a list |
494 return fieldvalues | 508 return fieldvalues |
495 | 509 |
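The default horizontal stacking above uses numpy.hstack when every value is an ndarray and otherwise returns the list of values unchanged. A small illustration under those assumptions (the dataset is hypothetical):

```python
import numpy

a = numpy.ones((4, 2))
b = numpy.zeros((4, 3))
numpy.hstack([a, b]).shape                       # (4, 5): two matrices, one wider matrix

# non-numpy field values fall through to the default: the list itself
dataset.valuesHStack(['f1', 'f2'], [[1], [2]])   # [[1], [2]]  (dataset hypothetical)
```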
496 | |
497 def valuesVStack(self,fieldname,values): | 510 def valuesVStack(self,fieldname,values): |
498 """ | 511 """ |
499 Return a value that corresponds to concatenating (vertically) several values of the | 512 @param fieldname: the name of the field from which the values were taken |
500 same field. This can be important to build a minibatch out of individual examples. This | 513 @type fieldname: any type |
501 is likely to involve a copy of the original values. When the values are numpy arrays, the | 514 |
502 result should be numpy.vstack(values). | 515 @param values: the field values to concatenate (e.g. batches from near the beginning or end of the dataset) |
503 The default is to use numpy.vstack for numpy.ndarray values, and a list | 516 @type values: list of minibatches (returned by minibatch_nowrap) |
504 pointing to the original values for other data types. | 517 |
505 """ | 518 @return: the concatenation (stacking) of the values |
506 all_numpy=True | 519 @rtype: something suitable as a minibatch field |
507 for value in values: | 520 """ |
508 if not type(value) is numpy.ndarray: | 521 rval = [] |
509 all_numpy=False | 522 for v in values: |
510 if all_numpy: | 523 rval.extend(v) |
511 return numpy.vstack(values) | 524 return rval |
512 # the default implementation of vertical stacking is to put values in a list | |
513 return values | |
514 | 525 |
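Note the behavioral change: the old default dispatched to numpy.vstack for arrays, while the new default flattens the per-batch values into one list with extend. A sketch of what each case now yields (ds and the values are illustrative):

```python
import numpy

# list-valued fields: per-batch lists are flattened into one list
ds.valuesVStack('f', [[1, 2], [3]])           # [1, 2, 3]   (ds hypothetical)

# numpy minibatches: iterating a 2-D array yields its rows, so the result
# is a list of 1-D row arrays, not a stacked 2-D array as before
ds.valuesVStack('f', [numpy.zeros((2, 3))])
# [array([ 0., 0., 0.]), array([ 0., 0., 0.])]
```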
515 def __or__(self,other): | 526 def __or__(self,other): |
516 """ | 527 """ |
517 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of | 528 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of |
518 fields of the argument datasets. This only works if they all have the same length. | 529 fields of the argument datasets. This only works if they all have the same length. |
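A hypothetical usage sketch of the | operator described here:

```python
# x_dataset and y_dataset are hypothetical datasets of equal length
both = x_dataset | y_dataset    # one dataset exposing the fields of both
for example in both:            # still len(x_dataset) examples
    pass
```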
584 return FieldsSubsetIterator(self) | 595 return FieldsSubsetIterator(self) |
585 | 596 |
586 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 597 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
587 assert self.hasFields(*fieldnames) | 598 assert self.hasFields(*fieldnames) |
588 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) | 599 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) |
589 def __getitem__(self,i): | 600 def dontuse__getitem__(self,i): |
590 return FieldsSubsetDataSet(self.src[i],self.fieldnames) | 601 return FieldsSubsetDataSet(self.src[i],self.fieldnames) |
591 | 602 |
592 | 603 |
593 class DataSetFields(LookupList): | 604 class DataSetFields(Example): |
594 """ | 605 """ |
595 Although a L{DataSet} iterates over examples (like rows of a matrix), an associated | 606 Although a L{DataSet} iterates over examples (like rows of a matrix), an associated |
596 DataSetFields iterates over fields (like columns of a matrix), and can be understood | 607 DataSetFields iterates over fields (like columns of a matrix), and can be understood |
597 as a transpose of the associated dataset. | 608 as a transpose of the associated dataset. |
598 | 609 |
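A sketch of the row/column duality described above, assuming the usual fields()/examples() round trip (the dataset is hypothetical):

```python
# dataset is hypothetical
fields = dataset.fields()    # column view of the same data
for field in fields:         # iterate over fields instead of examples
    print len(field)         # each field spans every example
rows = fields.examples()     # back to the example-iterable view
```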
626 dataset = FieldsSubsetDataSet(dataset,fieldnames) | 637 dataset = FieldsSubsetDataSet(dataset,fieldnames) |
627 assert dataset.hasFields(*fieldnames) | 638 assert dataset.hasFields(*fieldnames) |
628 self.dataset=dataset | 639 self.dataset=dataset |
629 | 640 |
630 if isinstance(dataset,MinibatchDataSet): | 641 if isinstance(dataset,MinibatchDataSet): |
631 LookupList.__init__(self,fieldnames,list(dataset._fields)) | 642 Example.__init__(self,fieldnames,list(dataset._fields)) |
632 elif isinstance(original_dataset,MinibatchDataSet): | 643 elif isinstance(original_dataset,MinibatchDataSet): |
633 LookupList.__init__(self,fieldnames, | 644 Example.__init__(self,fieldnames, |
634 [original_dataset._fields[field] | 645 [original_dataset._fields[field] |
635 for field in fieldnames]) | 646 for field in fieldnames]) |
636 else: | 647 else: |
637 minibatch_iterator = dataset.minibatches(fieldnames, | 648 minibatch_iterator = dataset.minibatches(fieldnames, |
638 minibatch_size=len(dataset), | 649 minibatch_size=len(dataset), |
639 n_batches=1) | 650 n_batches=1) |
640 minibatch=minibatch_iterator.next() | 651 minibatch=minibatch_iterator.next() |
641 LookupList.__init__(self,fieldnames,minibatch) | 652 Example.__init__(self,fieldnames,minibatch) |
642 | 653 |
643 def examples(self): | 654 def examples(self): |
644 return self.dataset | 655 return self.dataset |
645 | 656 |
646 def __or__(self,other): | 657 def __or__(self,other): |
658 return (self.examples() | other.examples()).fields() | 669 return (self.examples() | other.examples()).fields() |
659 | 670 |
660 | 671 |
661 class MinibatchDataSet(DataSet): | 672 class MinibatchDataSet(DataSet): |
662 """ | 673 """ |
663 Turn a L{LookupList} of same-length (iterable) fields into an example-iterable dataset. | 674 Turn an L{Example} of same-length (iterable) fields into an example-iterable dataset. |
664 Each element of the lookup-list should be an iterable and sliceable, all of the same length. | 675 Each element of the lookup-list should be an iterable and sliceable, all of the same length. |
665 """ | 676 """ |
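A minimal construction sketch, assuming hypothetical field names:

```python
from lookup_list import LookupList as Example

# hypothetical: two same-length fields become a three-example dataset
mb = Example(['x', 'y'], [[0, 1, 2], [3, 4, 5]])
ds = MinibatchDataSet(mb)
assert len(ds) == 3
```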
666 def __init__(self,fields_lookuplist,values_vstack=DataSet().valuesVStack, | 677 def __init__(self,fields_lookuplist,values_vstack=DataSet().valuesVStack, |
667 values_hstack=DataSet().valuesHStack): | 678 values_hstack=DataSet().valuesHStack): |
668 """ | 679 """ |
678 if self.length != len(field) : | 689 if self.length != len(field) : |
679 print 'self.length = ',self.length | 690 print 'self.length = ',self.length |
680 print 'len(field) = ', len(field) | 691 print 'len(field) = ', len(field) |
681 print 'self._fields.keys() = ', self._fields.keys() | 692 print 'self._fields.keys() = ', self._fields.keys() |
682 print 'field=',field | 693 print 'field=',field |
694 print 'fields_lookuplist=', fields_lookuplist | |
683 assert self.length==len(field) | 695 assert self.length==len(field) |
684 self.values_vstack=values_vstack | 696 self.valuesVStack=values_vstack |
685 self.values_hstack=values_hstack | 697 self.valuesHStack=values_hstack |
686 | 698 |
687 def __len__(self): | 699 def __len__(self): |
688 return self.length | 700 return self.length |
689 | 701 |
690 def __getitem__(self,i): | 702 def dontuse__getitem__(self,i): |
691 if type(i) in (slice,list): | 703 if type(i) in (slice,list): |
692 return DataSetFields(MinibatchDataSet( | 704 return DataSetFields(MinibatchDataSet( |
693 Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames()) | 705 Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames()) |
694 if type(i) is int: | 706 if type(i) is int: |
695 return Example(self._fields.keys(),[field[i] for field in self._fields]) | 707 return Example(self._fields.keys(),[field[i] for field in self._fields]) |
715 if fieldnames is None: fieldnames = ds._fields.keys() | 727 if fieldnames is None: fieldnames = ds._fields.keys() |
716 self.fieldnames = fieldnames | 728 self.fieldnames = fieldnames |
717 | 729 |
718 self.ds=ds | 730 self.ds=ds |
719 self.next_example=offset | 731 self.next_example=offset |
720 assert minibatch_size > 0 | 732 assert minibatch_size >= 0 |
721 if offset+minibatch_size > ds.length: | 733 if offset+minibatch_size > ds.length: |
722 raise NotImplementedError() | 734 raise NotImplementedError() |
723 def __iter__(self): | 735 def __iter__(self): |
724 return self | 736 return self |
725 def next(self): | 737 def next(self): |
739 return minibatch | 751 return minibatch |
740 | 752 |
741 # tbm: added fieldnames to handle subset of fieldnames | 753 # tbm: added fieldnames to handle subset of fieldnames |
742 return Iterator(self,fieldnames) | 754 return Iterator(self,fieldnames) |
743 | 755 |
744 def valuesVStack(self,fieldname,fieldvalues): | |
745 return self.values_vstack(fieldname,fieldvalues) | |
746 | |
747 def valuesHStack(self,fieldnames,fieldvalues): | |
748 return self.values_hstack(fieldnames,fieldvalues) | |
749 | |
750 class HStackedDataSet(DataSet): | 756 class HStackedDataSet(DataSet): |
751 """ | 757 """ |
752 A L{DataSet} that wraps several datasets and shows a view that includes all their fields, | 758 A L{DataSet} that wraps several datasets and shows a view that includes all their fields, |
753 i.e. whose list of fields is the concatenation of their lists of fields. | 759 i.e. whose list of fields is the concatenation of their lists of fields. |
754 | 760 |
808 self.iterators=iterators | 814 self.iterators=iterators |
809 def __iter__(self): | 815 def __iter__(self): |
810 return self | 816 return self |
811 def next(self): | 817 def next(self): |
812 # concatenate all the fields of the minibatches | 818 # concatenate all the fields of the minibatches |
813 l=LookupList() | 819 l=Example() |
814 for iter in self.iterators: | 820 for iter in self.iterators: |
815 l.append_lookuplist(iter.next()) | 821 l.append_lookuplist(iter.next()) |
816 return l | 822 return l |
817 | 823 |
818 assert self.hasFields(*fieldnames) | 824 assert self.hasFields(*fieldnames) |
832 datasets=self.datasets | 838 datasets=self.datasets |
833 iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets] | 839 iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets] |
834 return HStackedIterator(self,iterators) | 840 return HStackedIterator(self,iterators) |
835 | 841 |
836 | 842 |
837 def valuesVStack(self,fieldname,fieldvalues): | 843 def untested_valuesVStack(self,fieldname,fieldvalues): |
838 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues) | 844 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues) |
839 | 845 |
840 def valuesHStack(self,fieldnames,fieldvalues): | 846 def untested_valuesHStack(self,fieldnames,fieldvalues): |
841 """ | 847 """ |
842 We will use the sub-dataset associated with the first fieldname in the fieldnames list | 848 We will use the sub-dataset associated with the first fieldname in the fieldnames list |
843 to do the work, hoping that it can cope with the other values (i.e. won't care | 849 to do the work, hoping that it can cope with the other values (i.e. won't care |
844 about the incompatible fieldnames). Hence this heuristic will always work if | 850 about the incompatible fieldnames). Hence this heuristic will always work if |
845 all the fieldnames are of the same sub-dataset. | 851 all the fieldnames are of the same sub-dataset. |
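The description above suggests a dispatch along these lines; this is only a sketch of the heuristic, not the method body shown in the changeset:

```python
# a sketch only; the actual untested_valuesHStack body is not shown here
def sketch_valuesHStack(self, fieldnames, fieldvalues):
    # dispatch to the sub-dataset that owns the first field
    owner = self.datasets[self.fieldname2dataset[fieldnames[0]]]
    return owner.valuesHStack(fieldnames, fieldvalues)
```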
959 Virtual super-class of datasets whose field values are numpy array, | 965 Virtual super-class of datasets whose field values are numpy array, |
960 thus defining valuesHStack and valuesVStack for sub-classes. | 966 thus defining valuesHStack and valuesVStack for sub-classes. |
961 """ | 967 """ |
962 def __init__(self,description=None,field_types=None): | 968 def __init__(self,description=None,field_types=None): |
963 DataSet.__init__(self,description,field_types) | 969 DataSet.__init__(self,description,field_types) |
964 def valuesHStack(self,fieldnames,fieldvalues): | 970 def untested_valuesHStack(self,fieldnames,fieldvalues): |
965 """Concatenate field values horizontally, e.g. two vectors | 971 """Concatenate field values horizontally, e.g. two vectors |
966 become a longer vector, two matrices become a wider matrix, etc.""" | 972 become a longer vector, two matrices become a wider matrix, etc.""" |
967 return numpy.hstack(fieldvalues) | 973 return numpy.hstack(fieldvalues) |
968 def valuesVStack(self,fieldname,values): | 974 def untested_valuesVStack(self,fieldname,values): |
969 """Concatenate field values vertically, e.g. two vectors | 975 """Concatenate field values vertically, e.g. two vectors |
970 become a two-row matrix, two matrices become a longer matrix, etc.""" | 976 become a two-row matrix, two matrices become a longer matrix, etc.""" |
971 return numpy.vstack(values) | 977 return numpy.vstack(values) |
972 | 978 |
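A small numpy illustration of the two stacking directions described in these docstrings:

```python
import numpy

v1, v2 = numpy.array([1, 2, 3]), numpy.array([4, 5, 6])
numpy.hstack([v1, v2])          # array([1, 2, 3, 4, 5, 6]): a longer vector
numpy.vstack([v1, v2]).shape    # (2, 3): two vectors become a two-row matrix
```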
973 class ArrayDataSet(ArrayFieldsDataSet): | 979 class ArrayDataSet(ArrayFieldsDataSet): |
1017 return self.fields_columns.keys() | 1023 return self.fields_columns.keys() |
1018 | 1024 |
1019 def __len__(self): | 1025 def __len__(self): |
1020 return len(self.data) | 1026 return len(self.data) |
1021 | 1027 |
1022 def __getitem__(self,key): | 1028 def dontuse__getitem__(self,key): |
1023 """More efficient implementation than the default __getitem__""" | 1029 """More efficient implementation than the default __getitem__""" |
1024 fieldnames=self.fields_columns.keys() | 1030 fieldnames=self.fields_columns.keys() |
1025 values=self.fields_columns.values() | 1031 values=self.fields_columns.values() |
1026 if type(key) is int: | 1032 if type(key) is int: |
1027 return Example(fieldnames, | 1033 return Example(fieldnames, |
1049 return self.data[:,self.fields_columns[key]] | 1055 return self.data[:,self.fields_columns[key]] |
1050 # else we are trying to access a property of the dataset | 1056 # else we are trying to access a property of the dataset |
1051 assert key in self.__dict__ # else it means we are trying to access a non-existing property | 1057 assert key in self.__dict__ # else it means we are trying to access a non-existing property |
1052 return self.__dict__[key] | 1058 return self.__dict__[key] |
1053 | 1059 |
1054 def __iter__(self): | 1060 def dontuse__iter__(self): |
1055 class ArrayDataSetIteratorIter(object): | 1061 class ArrayDataSetIteratorIter(object): |
1056 def __init__(self,dataset,fieldnames): | 1062 def __init__(self,dataset,fieldnames): |
1057 if fieldnames is None: fieldnames = dataset.fieldNames() | 1063 if fieldnames is None: fieldnames = dataset.fieldNames() |
1058 # store the resulting minibatch in a lookup-list of values | 1064 # store the resulting minibatch in a lookup-list of values |
1059 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) | 1065 self.minibatch = Example(fieldnames,[0]*len(fieldnames)) |
1060 self.dataset=dataset | 1066 self.dataset=dataset |
1061 self.current=0 | 1067 self.current=0 |
1062 self.columns = [self.dataset.fields_columns[f] | 1068 self.columns = [self.dataset.fields_columns[f] |
1063 for f in self.minibatch._names] | 1069 for f in self.minibatch._names] |
1064 self.l = self.dataset.data.shape[0] | 1070 self.l = self.dataset.data.shape[0] |
1076 return self.minibatch | 1082 return self.minibatch |
1077 | 1083 |
1078 return ArrayDataSetIteratorIter(self,self.fieldNames()) | 1084 return ArrayDataSetIteratorIter(self,self.fieldNames()) |
1079 | 1085 |
1080 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 1086 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
1081 class ArrayDataSetIterator(object): | 1087 cursor = Example(fieldnames,[0]*len(fieldnames)) |
1082 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): | 1088 fieldnames = self.fieldNames() if fieldnames is None else fieldnames |
1083 if fieldnames is None: fieldnames = dataset.fieldNames() | 1089 for n in xrange(n_batches): |
1084 # store the resulting minibatch in a lookup-list of values | 1090 if offset == len(self): |
1085 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) | 1091 break |
1086 self.dataset=dataset | 1092 sub_data = self.data[offset : offset+minibatch_size] |
1087 self.minibatch_size=minibatch_size | 1093 offset += len(sub_data) #can be less than minibatch_size at end |
1088 assert offset>=0 and offset<len(dataset.data) | 1094 cursor._values = [sub_data[:,self.fields_columns[f]] for f in cursor._names] |
1089 assert offset+minibatch_size<=len(dataset.data) | 1095 yield cursor |
1090 self.current=offset | 1096 |
1091 def __iter__(self): | 1097 #return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) |
1092 return self | |
1093 def next(self): | |
1094 #@todo: we suppose that MinibatchWrapAroundIterator stop the iterator | |
1095 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] | |
1096 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names] | |
1097 self.current+=self.minibatch_size | |
1098 return self.minibatch | |
1099 | |
1100 return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) | |
1101 | 1098 |
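The rewritten method is now a generator that re-uses a single cursor Example and yields array views into self.data; a hypothetical caller should copy any minibatch that must outlive the loop:

```python
# ds and field 'x' are hypothetical; copy each batch that must outlive
# the loop, since the generator re-uses one cursor Example and yields
# slices (views) of ds.data
batches = [b['x'].copy()
           for b in ds.minibatches_nowrap(['x'], minibatch_size=10,
                                          n_batches=3, offset=0)]
```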
1102 | 1099 |
1103 class CachedDataSet(DataSet): | 1100 class CachedDataSet(DataSet): |
1104 """ | 1101 """ |
1105 Wrap a L{DataSet} whose values are computationally expensive to obtain | 1102 Wrap a L{DataSet} whose values are computationally expensive to obtain |
1160 if self.all_fields: | 1157 if self.all_fields: |
1161 return all_fields_minibatch | 1158 return all_fields_minibatch |
1162 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) | 1159 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) |
1163 return CacheIterator(self) | 1160 return CacheIterator(self) |
1164 | 1161 |
1165 def __getitem__(self,i): | 1162 def dontuse__getitem__(self,i): |
1166 if type(i)==int and len(self.cached_examples)>i: | 1163 if type(i)==int and len(self.cached_examples)>i: |
1167 return self.cached_examples[i] | 1164 return self.cached_examples[i] |
1168 else: | 1165 else: |
1169 return self.source_dataset[i] | 1166 return self.source_dataset[i] |
1170 | 1167 |
1173 def __init__(self,dataset): | 1170 def __init__(self,dataset): |
1174 self.dataset=dataset | 1171 self.dataset=dataset |
1175 self.l = len(dataset) | 1172 self.l = len(dataset) |
1176 self.current = 0 | 1173 self.current = 0 |
1177 self.fieldnames = self.dataset.fieldNames() | 1174 self.fieldnames = self.dataset.fieldNames() |
1178 self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames)) | 1175 self.example = Example(self.fieldnames,[0]*len(self.fieldnames)) |
1179 def __iter__(self): return self | 1176 def __iter__(self): return self |
1180 def next(self): | 1177 def next(self): |
1181 if self.current>=self.l: | 1178 if self.current>=self.l: |
1182 raise StopIteration | 1179 raise StopIteration |
1183 cache_len = len(self.dataset.cached_examples) | 1180 cache_len = len(self.dataset.cached_examples) |
1190 return self.example | 1187 return self.example |
1191 | 1188 |
1192 return CacheIteratorIter(self) | 1189 return CacheIteratorIter(self) |
1193 | 1190 |
1194 class ApplyFunctionDataSet(DataSet): | 1191 class ApplyFunctionDataSet(DataSet): |
1195 """ | 1192 """ |
1196 A L{DataSet} that contains as fields the results of applying a | 1193 A L{DataSet} that contains as fields the results of applying a |
1197 given function example-wise or minibatch-wise to all the fields of | 1194 given function example-wise or minibatch-wise to all the fields of |
1198 an input dataset. The output of the function should be an iterable | 1195 an input dataset. The output of the function should be an iterable |
1199 (e.g. a list or a LookupList) over the resulting values. | 1196 (e.g. a list or an Example) over the resulting values. |
1200 | 1197 |
1201 The function takes as input the fields of the dataset, not the examples. | 1198 The function takes as input the fields of the dataset, not the examples. |
1202 | 1199 |
1203 In minibatch mode, the function is expected to work on minibatches | 1200 In minibatch mode, the function is expected to work on minibatches |
1204 (takes a minibatch in input and returns a minibatch in output). More | 1201 (takes a minibatch in input and returns a minibatch in output). More |
1205 precisely, it means that each element of the input or output list | 1202 precisely, it means that each element of the input or output list |
1206 should be iterable and indexable over the individual example values | 1203 should be iterable and indexable over the individual example values |
1207 (typically these elements will be numpy arrays). All of the elements | 1204 (typically these elements will be numpy arrays). All of the elements |
1208 in the input and output lists should have the same length, which is | 1205 in the input and output lists should have the same length, which is |
1209 the length of the minibatch. | 1206 the length of the minibatch. |
1210 | 1207 |
1211 The function is applied each time an example or a minibatch is accessed. | 1208 The function is applied each time an example or a minibatch is accessed. |
1212 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. | 1209 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. |
1213 | 1210 |
1214 If the values_{h,v}stack functions are not provided, then | 1211 If the values_{h,v}stack functions are not provided, then |
1215 the input_dataset.values{H,V}Stack functions are used by default. | 1212 the input_dataset.values{H,V}Stack functions are used by default. |
1216 """ | 1213 """ |
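A hypothetical usage sketch, assuming a one-field input dataset ds with field 'x'; in minibatch mode the function maps whole field containers to whole field containers:

```python
import numpy

# hypothetical: ds has a single field 'x'; square it, minibatch-wise
def square(x_values):
    return (numpy.asarray(x_values) ** 2,)   # one output field per batch

squared = ApplyFunctionDataSet(ds, square, output_names=['x2'])
```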
1217 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, | 1214 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, |
1218 values_hstack=None,values_vstack=None, | 1215 values_hstack=None,values_vstack=None, |
1219 description=None,fieldtypes=None): | 1216 description=None,fieldtypes=None): |
1220 """ | 1217 """ |
1221 Constructor takes an input dataset that has as many fields as the function | 1218 Constructor takes an input dataset that has as many fields as the function |
1222 expects as inputs. The resulting dataset has as many fields as the function | 1219 expects as inputs. The resulting dataset has as many fields as the function |
1223 produces as outputs, and that should correspond to the number of output names | 1220 produces as outputs, and that should correspond to the number of output names |
1224 (provided in a list). | 1221 (provided in a list). |
1225 | 1222 |
1226 Note that the expected semantics of the function differs in minibatch mode | 1223 Note that the expected semantics of the function differs in minibatch mode |
1227 (it takes minibatches of inputs and produces minibatches of outputs, as | 1224 (it takes minibatches of inputs and produces minibatches of outputs, as |
1228 documented in the class comment). | 1225 documented in the class comment). |
1229 | 1226 |
1230 TBM: are fieldtypes the old field types (from input_dataset) or the new ones | 1227 TBM: are fieldtypes the old field types (from input_dataset) or the new ones |
1231 (for the new dataset created)? | 1228 (for the new dataset created)? |
1232 """ | 1229 """ |
1233 self.input_dataset=input_dataset | 1230 self.input_dataset=input_dataset |
1234 self.function=function | 1231 self.function=function |
1235 self.output_names=output_names | 1232 self.output_names=output_names |
1236 self.minibatch_mode=minibatch_mode | 1233 self.minibatch_mode=minibatch_mode |
1237 DataSet.__init__(self,description,fieldtypes) | 1234 DataSet.__init__(self,description,fieldtypes) |
1238 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack | 1235 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack |
1239 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack | 1236 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack |
1240 | 1237 |
1241 def __len__(self): | 1238 def __len__(self): |
1242 return len(self.input_dataset) | 1239 return len(self.input_dataset) |
1243 | 1240 |
1244 def fieldNames(self): | 1241 def fieldNames(self): |
1245 return self.output_names | 1242 return self.output_names |
1246 | 1243 |
1247 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 1244 def minibatches_nowrap(self, fieldnames, *args, **kwargs): |
1248 class ApplyFunctionIterator(object): | 1245 for fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs): |
1249 def __init__(self,output_dataset): | 1246 |
1250 self.input_dataset=output_dataset.input_dataset | 1247 #function_inputs = self.input_iterator.next() |
1251 self.output_dataset=output_dataset | 1248 if self.minibatch_mode: |
1252 self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size, | 1249 function_outputs = self.function(*fields) |
1253 n_batches=n_batches,offset=offset).__iter__() | 1250 else: |
1254 | 1251 input_examples = zip(*fields) |
1255 def __iter__(self): return self | 1252 output_examples = [self.function(*input_example) |
1256 | 1253 for input_example in input_examples] |
1257 def next(self): | 1254 function_outputs = [self.valuesVStack(name,values) |
1258 function_inputs = self.input_iterator.next() | 1255 for name,values in zip(self.output_names, |
1259 all_output_names = self.output_dataset.output_names | 1256 zip(*output_examples))] |
1260 if self.output_dataset.minibatch_mode: | 1257 all_outputs = Example(self.output_names, function_outputs) |
1261 function_outputs = self.output_dataset.function(*function_inputs) | 1258 print fields |
1262 else: | 1259 print all_outputs |
1263 input_examples = zip(*function_inputs) | 1260 print '--------' |
1264 output_examples = [self.output_dataset.function(*input_example) | 1261 if fieldnames==self.output_names: |
1265 for input_example in input_examples] | 1262 yield all_outputs |
1266 function_outputs = [self.output_dataset.valuesVStack(name,values) | 1263 else: |
1267 for name,values in zip(all_output_names, | 1264 yield Example(fieldnames,[all_outputs[name] for name in fieldnames]) |
1268 zip(*output_examples))] | 1265 |
1269 all_outputs = Example(all_output_names,function_outputs) | 1266 def untested__iter__(self): # only implemented for increased efficiency |
1270 if fieldnames==all_output_names: | 1267 class ApplyFunctionSingleExampleIterator(object): |
1271 return all_outputs | 1268 def __init__(self,output_dataset): |
1272 return Example(fieldnames,[all_outputs[name] for name in fieldnames]) | 1269 self.current=0 |
1273 | 1270 self.output_dataset=output_dataset |
1274 | 1271 self.input_iterator=output_dataset.input_dataset.__iter__() |
1275 return ApplyFunctionIterator(self) | 1272 def __iter__(self): return self |
1276 | 1273 def next(self): |
1277 def __iter__(self): # only implemented for increased efficiency | 1274 if self.output_dataset.minibatch_mode: |
1278 class ApplyFunctionSingleExampleIterator(object): | 1275 function_inputs = [[input] for input in self.input_iterator.next()] |
1279 def __init__(self,output_dataset): | 1276 outputs = self.output_dataset.function(*function_inputs) |
1280 self.current=0 | 1277 assert all([hasattr(output,'__iter__') for output in outputs]) |
1281 self.output_dataset=output_dataset | 1278 function_outputs = [output[0] for output in outputs] |
1282 self.input_iterator=output_dataset.input_dataset.__iter__() | 1279 else: |
1283 def __iter__(self): return self | 1280 function_inputs = self.input_iterator.next() |
1284 def next(self): | 1281 function_outputs = self.output_dataset.function(*function_inputs) |
1285 if self.output_dataset.minibatch_mode: | 1282 return Example(self.output_dataset.output_names,function_outputs) |
1286 function_inputs = [[input] for input in self.input_iterator.next()] | 1283 return ApplyFunctionSingleExampleIterator(self) |
1287 outputs = self.output_dataset.function(*function_inputs) | 1284 |
1288 assert all([hasattr(output,'__iter__') for output in outputs]) | |
1289 function_outputs = [output[0] for output in outputs] | |
1290 else: | |
1291 function_inputs = self.input_iterator.next() | |
1292 function_outputs = self.output_dataset.function(*function_inputs) | |
1293 return Example(self.output_dataset.output_names,function_outputs) | |
1294 return ApplyFunctionSingleExampleIterator(self) | |
1295 | |
1296 | 1285 |
1297 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 1286 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
1298 """ | 1287 """ |
1299 Wraps an arbitrary L{DataSet} into one for supervised learning tasks | 1288 Wraps an arbitrary L{DataSet} into one for supervised learning tasks |
1300 by forcing the user to define a set of fields as the 'input' field | 1289 by forcing the user to define a set of fields as the 'input' field |