# HG changeset patch # User bengioy@grenat.iro.umontreal.ca # Date 1209405174 14400 # Node ID 5a85fda9b19b3455d7db050350a1beb18890c85d # Parent e92244f3011640b87c7044f774d831565c0b8480 Fixed some more iterator bugs diff -r e92244f30116 -r 5a85fda9b19b dataset.py --- a/dataset.py Mon Apr 28 11:41:28 2008 -0400 +++ b/dataset.py Mon Apr 28 13:52:54 2008 -0400 @@ -139,11 +139,16 @@ """ def __init__(self, minibatch_iterator): self.minibatch_iterator = minibatch_iterator + self.minibatch = None def __iter__(self): #makes for loop work return self def next(self): size1_minibatch = self.minibatch_iterator.next() - return Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()]) + if not self.minibatch: + self.minibatch = Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()]) + else: + self.minibatch._values = [value[0] for value in size1_minibatch.values()] + return self.minibatch def next_index(self): return self.minibatch_iterator.next_index() @@ -476,16 +481,22 @@ return self.fieldnames def __iter__(self): - class Iterator(object): + class FieldsSubsetIterator(object): def __init__(self,ds): self.ds=ds self.src_iter=ds.src.__iter__() + self.example=None def __iter__(self): return self def next(self): - example = self.src_iter.next() - return Example(self.ds.fieldnames, - [example[field] for field in self.ds.fieldnames]) - return Iterator(self) + complete_example = self.src_iter.next() + if self.example: + self.example._values=[complete_example[field] + for field in self.ds.fieldnames] + else: + self.example=Example(self.ds.fieldnames, + [complete_example[field] for field in self.ds.fieldnames]) + return self.example + return FieldsSubsetIterator(self) def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): assert self.hasFields(*fieldnames) @@ -670,7 +681,7 @@ def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): - class Iterator(object): + class HStackedIterator(object): def __init__(self,hsds,iterators): self.hsds=hsds self.iterators=iterators @@ -700,7 +711,7 @@ else: datasets=self.datasets iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets] - return Iterator(self,iterators) + return HStackedIterator(self,iterators) def valuesVStack(self,fieldname,fieldvalues): @@ -768,8 +779,8 @@ return dataset_index, row_within_dataset def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): - - class Iterator(object): + + class VStackedIterator(object): def __init__(self,vsds): self.vsds=vsds self.next_row=offset @@ -824,7 +835,8 @@ self.next_dataset_row+=minibatch_size if self.next_row+minibatch_size>len(dataset): self.move_to_next_dataset() - return + return examples + return VStackedIterator(self) class ArrayFieldsDataSet(DataSet): """ @@ -886,10 +898,11 @@ # """More efficient implementation than the default""" def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): - class Iterator(LookupList): # store the result in the lookup-list values + class ArrayDataSetIterator(object): def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): if fieldnames is None: fieldnames = dataset.fieldNames() - LookupList.__init__(self,fieldnames,[0]*len(fieldnames)) + # store the resulting minibatch in a lookup-list of values + self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) self.dataset=dataset self.minibatch_size=minibatch_size assert offset>=0 and offset