pylearn: diff dataset.py @ 44:5a85fda9b19b
Fixed some more iterator bugs
author    bengioy@grenat.iro.umontreal.ca
date      Mon, 28 Apr 2008 13:52:54 -0400
parents   e92244f30116
children  a5c70dc42972
--- a/dataset.py	Mon Apr 28 11:41:28 2008 -0400
+++ b/dataset.py	Mon Apr 28 13:52:54 2008 -0400
@@ -139,11 +139,16 @@
     """
     def __init__(self, minibatch_iterator):
         self.minibatch_iterator = minibatch_iterator
+        self.minibatch = None
     def __iter__(self): #makes for loop work
         return self
     def next(self):
         size1_minibatch = self.minibatch_iterator.next()
-        return Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()])
+        if not self.minibatch:
+            self.minibatch = Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()])
+        else:
+            self.minibatch._values = [value[0] for value in size1_minibatch.values()]
+        return self.minibatch
 
     def next_index(self):
         return self.minibatch_iterator.next_index()
@@ -476,16 +481,22 @@
         return self.fieldnames
 
     def __iter__(self):
-        class Iterator(object):
+        class FieldsSubsetIterator(object):
             def __init__(self,ds):
                 self.ds=ds
                 self.src_iter=ds.src.__iter__()
+                self.example=None
             def __iter__(self): return self
             def next(self):
-                example = self.src_iter.next()
-                return Example(self.ds.fieldnames,
-                               [example[field] for field in self.ds.fieldnames])
-        return Iterator(self)
+                complete_example = self.src_iter.next()
+                if self.example:
+                    self.example._values=[complete_example[field]
+                                          for field in self.ds.fieldnames]
+                else:
+                    self.example=Example(self.ds.fieldnames,
+                                         [complete_example[field] for field in self.ds.fieldnames])
+                return self.example
+        return FieldsSubsetIterator(self)
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         assert self.hasFields(*fieldnames)
@@ -670,7 +681,7 @@
 
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
-        class Iterator(object):
+        class HStackedIterator(object):
             def __init__(self,hsds,iterators):
                 self.hsds=hsds
                 self.iterators=iterators
@@ -700,7 +711,7 @@
         else:
             datasets=self.datasets
         iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets]
-        return Iterator(self,iterators)
+        return HStackedIterator(self,iterators)
 
 
     def valuesVStack(self,fieldname,fieldvalues):
@@ -768,8 +779,8 @@
         return dataset_index, row_within_dataset
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
-        
-        class Iterator(object):
+
+        class VStackedIterator(object):
             def __init__(self,vsds):
                 self.vsds=vsds
                 self.next_row=offset
@@ -824,7 +835,8 @@
                 self.next_dataset_row+=minibatch_size
                 if self.next_row+minibatch_size>len(dataset):
                     self.move_to_next_dataset()
-                return
+                return examples
+        return VStackedIterator(self)
 
 class ArrayFieldsDataSet(DataSet):
     """
@@ -886,10 +898,11 @@
 #    """More efficient implementation than the default"""
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
-        class Iterator(LookupList): # store the result in the lookup-list values
+        class ArrayDataSetIterator(object):
             def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
                 if fieldnames is None: fieldnames = dataset.fieldNames()
-                LookupList.__init__(self,fieldnames,[0]*len(fieldnames))
+                # store the resulting minibatch in a lookup-list of values
+                self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
                 self.dataset=dataset
                 self.minibatch_size=minibatch_size
                 assert offset>=0 and offset<len(dataset.data)
@@ -899,11 +912,11 @@
                 return self
             def next(self):
                 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size]
-                self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names]
+                self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names]
                 self.current+=self.minibatch_size
-                return self
+                return self.minibatch
 
-        return Iterator(self,fieldnames,minibatch_size,n_batches,offset)
+        return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
 
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
     """
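The recurring change above is an object-reuse pattern: each renamed iterator now allocates its result container (an Example or LookupList) once and only overwrites its _values on later next() calls, instead of constructing a fresh object per example or minibatch. Below is a minimal, self-contained sketch of that pattern in Python 2 (matching the iterator protocol used in the diff); ReusingIterator and the plain dict result are illustrative stand-ins, not part of the pylearn API.

class ReusingIterator(object):
    """Yield the same dict-like result on every next() call, refreshed in
    place from an underlying iterator over lists of values."""
    def __init__(self, names, source_iterator):
        self.names = names
        self.source_iterator = source_iterator
        self.result = None  # allocated lazily on the first next(), as in the diff
    def __iter__(self):
        return self
    def next(self):  # Python 2 iterator protocol
        values = self.source_iterator.next()
        if self.result is None:
            # first call: build the container once
            self.result = dict(zip(self.names, values))
        else:
            # later calls: overwrite the stored values in place
            for name, value in zip(self.names, values):
                self.result[name] = value
        return self.result

# Usage: both calls return the *same* object, refreshed in place.
it = ReusingIterator(['x', 'y'], iter([[1, 2], [3, 4]]))
first = it.next()   # {'x': 1, 'y': 2}
second = it.next()  # {'x': 3, 'y': 4}; 'first' is the same object and now
                    # shows these values too

The trade-off is that every iteration hands back the very same object, so a caller that needs to keep an example or minibatch beyond the current step must copy its values before advancing the iterator.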