pylearn: comparison of dataset.py @ 44:5a85fda9b19b
Fixed some more iterator bugs
author | bengioy@grenat.iro.umontreal.ca |
date | Mon, 28 Apr 2008 13:52:54 -0400 |
parents | e92244f30116 |
children | a5c70dc42972 |
43:e92244f30116 | 44:5a85fda9b19b |
---|---|
137 Example field an iterable object over the individual examples in | 137 Example field an iterable object over the individual examples in |
138 the minibatch). | 138 the minibatch). |
139 """ | 139 """ |
140 def __init__(self, minibatch_iterator): | 140 def __init__(self, minibatch_iterator): |
141 self.minibatch_iterator = minibatch_iterator | 141 self.minibatch_iterator = minibatch_iterator |
| 142 self.minibatch = None |
142 def __iter__(self): #makes for loop work | 143 def __iter__(self): #makes for loop work |
143 return self | 144 return self |
144 def next(self): | 145 def next(self): |
145 size1_minibatch = self.minibatch_iterator.next() | 146 size1_minibatch = self.minibatch_iterator.next() |
146 return Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()]) | 147 if not self.minibatch: |
| 148 self.minibatch = Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()]) |
| 149 else: |
| 150 self.minibatch._values = [value[0] for value in size1_minibatch.values()] |
| 151 return self.minibatch |
147 | 152 |
148 def next_index(self): | 153 def next_index(self): |
149 return self.minibatch_iterator.next_index() | 154 return self.minibatch_iterator.next_index() |
150 | 155 |
151 def __iter__(self): | 156 def __iter__(self): |
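The hunk above makes MinibatchToSingleExampleIterator.next() reuse a single cached Example, overwriting its _values on every call instead of allocating a new object per example. A minimal, self-contained sketch of that reuse pattern (SimpleExample and ReusingIterator below are hypothetical stand-ins for the repository's Example/LookupList and iterator classes; next() follows the Python 2 iterator protocol used in the file):

class SimpleExample(object):
    """Hypothetical stand-in for Example: named fields backed by a list of values."""
    def __init__(self, names, values):
        self._names = names
        self._values = values

class ReusingIterator(object):
    """Yields one example per call, reusing a single cached result object."""
    def __init__(self, minibatch_iterator):
        self.minibatch_iterator = minibatch_iterator
        self.minibatch = None                    # built lazily on the first call
    def __iter__(self):
        return self
    def next(self):                              # Python 2 iterator protocol, as in the file
        size1_minibatch = self.minibatch_iterator.next()
        values = [value[0] for value in size1_minibatch.values()]
        if self.minibatch is None:
            self.minibatch = SimpleExample(size1_minibatch.keys(), values)
        else:
            self.minibatch._values = values      # overwrite in place, no new allocation
        return self.minibatch                    # the caller gets the SAME object every call

# e.g. with dicts standing in for size-1 minibatches:
for ex in ReusingIterator(iter([{'x': [1]}, {'x': [2]}])):
    print ex._values

Because the same object comes back on every call, code that accumulates yielded examples across iterations should copy example._values before advancing the iterator.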
474 | 479 |
475 def fieldNames(self): | 480 def fieldNames(self): |
476 return self.fieldnames | 481 return self.fieldnames |
477 | 482 |
478 def __iter__(self): | 483 def __iter__(self): |
479 class Iterator(object): | 484 class FieldsSubsetIterator(object): |
480 def __init__(self,ds): | 485 def __init__(self,ds): |
481 self.ds=ds | 486 self.ds=ds |
482 self.src_iter=ds.src.__iter__() | 487 self.src_iter=ds.src.__iter__() |
| 488 self.example=None |
483 def __iter__(self): return self | 489 def __iter__(self): return self |
484 def next(self): | 490 def next(self): |
485 example = self.src_iter.next() | 491 complete_example = self.src_iter.next() |
486 return Example(self.ds.fieldnames, | 492 if self.example: |
487 [example[field] for field in self.ds.fieldnames]) | 493 self.example._values=[complete_example[field] |
488 return Iterator(self) | 494 for field in self.ds.fieldnames] |
| 495 else: |
| 496 self.example=Example(self.ds.fieldnames, |
| 497 [complete_example[field] for field in self.ds.fieldnames]) |
| 498 return self.example |
| 499 return FieldsSubsetIterator(self) |
489 | 500 |
490 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 501 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
491 assert self.hasFields(*fieldnames) | 502 assert self.hasFields(*fieldnames) |
492 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) | 503 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) |
493 def __getitem__(self,i): | 504 def __getitem__(self,i): |
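FieldsSubsetDataSet's iterator above applies the same caching idea to a projection: each complete source example is reduced to the requested fields and refilled into one reused result. A dict-based sketch of that projection (plain dicts stand in for the repository's Example objects; class and variable names here are illustrative only):

class FieldsSubsetSketch(object):
    def __init__(self, src_examples, fieldnames):
        self.src_iter = iter(src_examples)
        self.fieldnames = fieldnames
        self.example = None                      # reused result, created on first call
    def __iter__(self):
        return self
    def next(self):                              # Python 2 iterator protocol
        complete_example = self.src_iter.next()
        values = [complete_example[f] for f in self.fieldnames]
        if self.example is None:
            self.example = dict(zip(self.fieldnames, values))
        else:
            self.example.update(zip(self.fieldnames, values))  # refill in place
        return self.example

# e.g. keep only the 'input' field of each row:
rows = [{'input': [0, 1], 'target': 0}, {'input': [1, 0], 'target': 1}]
for ex in FieldsSubsetSketch(rows, ['input']):
    print ex['input']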
668 def fieldNames(self): | 679 def fieldNames(self): |
669 return self.fieldname2dataset.keys() | 680 return self.fieldname2dataset.keys() |
670 | 681 |
671 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 682 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
672 | 683 |
673 class Iterator(object): | 684 class HStackedIterator(object): |
674 def __init__(self,hsds,iterators): | 685 def __init__(self,hsds,iterators): |
675 self.hsds=hsds | 686 self.hsds=hsds |
676 self.iterators=iterators | 687 self.iterators=iterators |
677 def __iter__(self): | 688 def __iter__(self): |
678 return self | 689 return self |
698 iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches,offset) | 709 iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches,offset) |
699 for dataset in datasets] | 710 for dataset in datasets] |
700 else: | 711 else: |
701 datasets=self.datasets | 712 datasets=self.datasets |
702 iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets] | 713 iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets] |
703 return Iterator(self,iterators) | 714 return HStackedIterator(self,iterators) |
704 | 715 |
705 | 716 |
706 def valuesVStack(self,fieldname,fieldvalues): | 717 def valuesVStack(self,fieldname,fieldvalues): |
707 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues) | 718 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues) |
708 | 719 |
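HStackedDataSet, whose minibatches_nowrap appears above, joins several datasets that describe the same examples through different fields. A simplified sketch of horizontal stacking, drawing one minibatch from each source in lock-step and merging the fields (plain dicts stand in for Example/LookupList, field names are assumed disjoint, and the fields_in_dataset bookkeeping of the real code is ignored):

from itertools import izip

def hstack_minibatches(iterators):
    """Merge the fields of one minibatch per source, step by step."""
    for parts in izip(*iterators):               # one minibatch from each source
        merged = {}
        for part in parts:
            merged.update(part)                  # disjoint field names assumed
        yield merged

# e.g. two sources providing the 'x' and 'y' fields of the same rows:
left = iter([{'x': [1, 2]}, {'x': [3, 4]}])
right = iter([{'y': [5, 6]}, {'y': [7, 8]}])
for mb in hstack_minibatches([left, right]):
    print mb['x'], mb['y']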
766 dataset_index = self.index2dataset[row] | 777 dataset_index = self.index2dataset[row] |
767 row_within_dataset = self.datasets_start_row[dataset_index] | 778 row_within_dataset = self.datasets_start_row[dataset_index] |
768 return dataset_index, row_within_dataset | 779 return dataset_index, row_within_dataset |
769 | 780 |
770 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 781 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
771 | 782 |
772 class Iterator(object): | 783 class VStackedIterator(object): |
773 def __init__(self,vsds): | 784 def __init__(self,vsds): |
774 self.vsds=vsds | 785 self.vsds=vsds |
775 self.next_row=offset | 786 self.next_row=offset |
776 self.next_dataset_index,self.next_dataset_row=self.vsds.locate_row(offset) | 787 self.next_dataset_index,self.next_dataset_row=self.vsds.locate_row(offset) |
777 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ | 788 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ |
822 | 833 |
823 self.next_row+=minibatch_size | 834 self.next_row+=minibatch_size |
824 self.next_dataset_row+=minibatch_size | 835 self.next_dataset_row+=minibatch_size |
825 if self.next_row+minibatch_size>len(dataset): | 836 if self.next_row+minibatch_size>len(dataset): |
826 self.move_to_next_dataset() | 837 self.move_to_next_dataset() |
827 return | 838 return examples |
| 839 return VStackedIterator(self) |
828 | 840 |
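The replaced row at old line 827 / new line 838 is the heart of the VStackedDataSet fix: next() previously ended with a bare return (handing None back to the caller), and the inserted row 839 means minibatches_nowrap now actually returns its iterator. A much-simplified sketch of vertical stacking that reads each source in turn as one continuous stream (it ignores minibatches straddling dataset boundaries, which the real VStackedIterator handles):

def vstack_examples(datasets):
    """Yield every example of every dataset, one source after another."""
    for ds in datasets:
        for example in ds:
            yield example

# e.g. two row sources read back to back:
for row in vstack_examples([[10, 11, 12], [13, 14]]):
    print row                                    # 10 11 12 13 14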
829 class ArrayFieldsDataSet(DataSet): | 841 class ArrayFieldsDataSet(DataSet): |
830 """ | 842 """ |
831 Virtual super-class of datasets whose field values are numpy array, | 843 Virtual super-class of datasets whose field values are numpy array, |
832 thus defining valuesHStack and valuesVStack for sub-classes. | 844 thus defining valuesHStack and valuesVStack for sub-classes. |
884 | 896 |
885 #def __getitem__(self,i): | 897 #def __getitem__(self,i): |
886 # """More efficient implementation than the default""" | 898 # """More efficient implementation than the default""" |
887 | 899 |
888 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 900 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
889 class Iterator(LookupList): # store the result in the lookup-list values | 901 class ArrayDataSetIterator(object): |
890 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): | 902 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): |
891 if fieldnames is None: fieldnames = dataset.fieldNames() | 903 if fieldnames is None: fieldnames = dataset.fieldNames() |
892 LookupList.__init__(self,fieldnames,[0]*len(fieldnames)) | 904 # store the resulting minibatch in a lookup-list of values |
| 905 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) |
893 self.dataset=dataset | 906 self.dataset=dataset |
894 self.minibatch_size=minibatch_size | 907 self.minibatch_size=minibatch_size |
895 assert offset>=0 and offset<len(dataset.data) | 908 assert offset>=0 and offset<len(dataset.data) |
896 assert offset+minibatch_size<=len(dataset.data) | 909 assert offset+minibatch_size<=len(dataset.data) |
897 self.current=offset | 910 self.current=offset |
898 def __iter__(self): | 911 def __iter__(self): |
899 return self | 912 return self |
900 def next(self): | 913 def next(self): |
901 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] | 914 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] |
902 self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names] | 915 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names] |
903 self.current+=self.minibatch_size | 916 self.current+=self.minibatch_size |
904 return self | 917 return self.minibatch |
905 | 918 |
906 return Iterator(self,fieldnames,minibatch_size,n_batches,offset) | 919 return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) |
907 | 920 |
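In the ArrayDataSet hunk above the iterator no longer inherits from LookupList; it keeps one in self.minibatch and refills its _values with column slices of each block of rows. A hedged numpy sketch of that slicing scheme (fields_columns maps a field name to its column indices, as in the diff; a plain dict stands in for LookupList and the function name is illustrative):

import numpy

def array_minibatches(data, fields_columns, minibatch_size):
    """Yield successive row blocks of `data`, split per field by column index."""
    minibatch = dict((f, None) for f in fields_columns)      # reused container
    for start in range(0, len(data) - minibatch_size + 1, minibatch_size):
        sub_data = data[start:start + minibatch_size]        # one block of rows
        for f, cols in fields_columns.items():
            minibatch[f] = sub_data[:, cols]                  # overwrite in place
        yield minibatch                                       # same dict every time

# e.g. a 6x3 array whose first two columns are 'input' and last is 'target':
data = numpy.arange(18).reshape(6, 3)
for mb in array_minibatches(data, {'input': [0, 1], 'target': [2]}, 2):
    print mb['input'].shape, mb['target'].shape              # (2, 2) (2, 1)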
908 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 921 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
909 """ | 922 """ |
910 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the | 923 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the |
911 user to define a set of fields as the 'input' field and a set of fields | 924 user to define a set of fields as the 'input' field and a set of fields |