comparison dataset.py @ 44:5a85fda9b19b

Fixed some more iterator bugs
author bengioy@grenat.iro.umontreal.ca
date Mon, 28 Apr 2008 13:52:54 -0400
parents e92244f30116
children a5c70dc42972
comparison
equal deleted inserted replaced
43:e92244f30116 44:5a85fda9b19b
137 Example field an iterable object over the individual examples in 137 Example field an iterable object over the individual examples in
138 the minibatch). 138 the minibatch).
139 """ 139 """
140 def __init__(self, minibatch_iterator): 140 def __init__(self, minibatch_iterator):
141 self.minibatch_iterator = minibatch_iterator 141 self.minibatch_iterator = minibatch_iterator
142 self.minibatch = None
142 def __iter__(self): #makes for loop work 143 def __iter__(self): #makes for loop work
143 return self 144 return self
144 def next(self): 145 def next(self):
145 size1_minibatch = self.minibatch_iterator.next() 146 size1_minibatch = self.minibatch_iterator.next()
146 return Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()]) 147 if not self.minibatch:
148 self.minibatch = Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()])
149 else:
150 self.minibatch._values = [value[0] for value in size1_minibatch.values()]
151 return self.minibatch
147 152
148 def next_index(self): 153 def next_index(self):
149 return self.minibatch_iterator.next_index() 154 return self.minibatch_iterator.next_index()
150 155
151 def __iter__(self): 156 def __iter__(self):
474 479
475 def fieldNames(self): 480 def fieldNames(self):
476 return self.fieldnames 481 return self.fieldnames
477 482
478 def __iter__(self): 483 def __iter__(self):
479 class Iterator(object): 484 class FieldsSubsetIterator(object):
480 def __init__(self,ds): 485 def __init__(self,ds):
481 self.ds=ds 486 self.ds=ds
482 self.src_iter=ds.src.__iter__() 487 self.src_iter=ds.src.__iter__()
488 self.example=None
483 def __iter__(self): return self 489 def __iter__(self): return self
484 def next(self): 490 def next(self):
485 example = self.src_iter.next() 491 complete_example = self.src_iter.next()
486 return Example(self.ds.fieldnames, 492 if self.example:
487 [example[field] for field in self.ds.fieldnames]) 493 self.example._values=[complete_example[field]
488 return Iterator(self) 494 for field in self.ds.fieldnames]
495 else:
496 self.example=Example(self.ds.fieldnames,
497 [complete_example[field] for field in self.ds.fieldnames])
498 return self.example
499 return FieldsSubsetIterator(self)
489 500
490 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 501 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
491 assert self.hasFields(*fieldnames) 502 assert self.hasFields(*fieldnames)
492 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) 503 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset)
493 def __getitem__(self,i): 504 def __getitem__(self,i):
668 def fieldNames(self): 679 def fieldNames(self):
669 return self.fieldname2dataset.keys() 680 return self.fieldname2dataset.keys()
670 681
671 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 682 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
672 683
673 class Iterator(object): 684 class HStackedIterator(object):
674 def __init__(self,hsds,iterators): 685 def __init__(self,hsds,iterators):
675 self.hsds=hsds 686 self.hsds=hsds
676 self.iterators=iterators 687 self.iterators=iterators
677 def __iter__(self): 688 def __iter__(self):
678 return self 689 return self
698 iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches,offset) 709 iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches,offset)
699 for dataset in datasets] 710 for dataset in datasets]
700 else: 711 else:
701 datasets=self.datasets 712 datasets=self.datasets
702 iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets] 713 iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets]
703 return Iterator(self,iterators) 714 return HStackedIterator(self,iterators)
704 715
705 716
706 def valuesVStack(self,fieldname,fieldvalues): 717 def valuesVStack(self,fieldname,fieldvalues):
707 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues) 718 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues)
708 719
766 dataset_index = self.index2dataset[row] 777 dataset_index = self.index2dataset[row]
767 row_within_dataset = self.datasets_start_row[dataset_index] 778 row_within_dataset = self.datasets_start_row[dataset_index]
768 return dataset_index, row_within_dataset 779 return dataset_index, row_within_dataset
769 780
770 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 781 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
771 782
772 class Iterator(object): 783 class VStackedIterator(object):
773 def __init__(self,vsds): 784 def __init__(self,vsds):
774 self.vsds=vsds 785 self.vsds=vsds
775 self.next_row=offset 786 self.next_row=offset
776 self.next_dataset_index,self.next_dataset_row=self.vsds.locate_row(offset) 787 self.next_dataset_index,self.next_dataset_row=self.vsds.locate_row(offset)
777 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ 788 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
822 833
823 self.next_row+=minibatch_size 834 self.next_row+=minibatch_size
824 self.next_dataset_row+=minibatch_size 835 self.next_dataset_row+=minibatch_size
825 if self.next_row+minibatch_size>len(dataset): 836 if self.next_row+minibatch_size>len(dataset):
826 self.move_to_next_dataset() 837 self.move_to_next_dataset()
827 return 838 return examples
839 return VStackedIterator(self)
828 840
829 class ArrayFieldsDataSet(DataSet): 841 class ArrayFieldsDataSet(DataSet):
830 """ 842 """
831 Virtual super-class of datasets whose field values are numpy array, 843 Virtual super-class of datasets whose field values are numpy array,
832 thus defining valuesHStack and valuesVStack for sub-classes. 844 thus defining valuesHStack and valuesVStack for sub-classes.
884 896
885 #def __getitem__(self,i): 897 #def __getitem__(self,i):
886 # """More efficient implementation than the default""" 898 # """More efficient implementation than the default"""
887 899
888 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 900 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
889 class Iterator(LookupList): # store the result in the lookup-list values 901 class ArrayDataSetIterator(object):
890 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 902 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
891 if fieldnames is None: fieldnames = dataset.fieldNames() 903 if fieldnames is None: fieldnames = dataset.fieldNames()
892 LookupList.__init__(self,fieldnames,[0]*len(fieldnames)) 904 # store the resulting minibatch in a lookup-list of values
905 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
893 self.dataset=dataset 906 self.dataset=dataset
894 self.minibatch_size=minibatch_size 907 self.minibatch_size=minibatch_size
895 assert offset>=0 and offset<len(dataset.data) 908 assert offset>=0 and offset<len(dataset.data)
896 assert offset+minibatch_size<=len(dataset.data) 909 assert offset+minibatch_size<=len(dataset.data)
897 self.current=offset 910 self.current=offset
898 def __iter__(self): 911 def __iter__(self):
899 return self 912 return self
900 def next(self): 913 def next(self):
901 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] 914 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size]
902 self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names] 915 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names]
903 self.current+=self.minibatch_size 916 self.current+=self.minibatch_size
904 return self 917 return self.minibatch
905 918
906 return Iterator(self,fieldnames,minibatch_size,n_batches,offset) 919 return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
907 920
908 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): 921 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
909 """ 922 """
910 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the 923 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the
911 user to define a set of fields as the 'input' field and a set of fields 924 user to define a set of fields as the 'input' field and a set of fields