# HG changeset patch # User Frederic Bastien # Date 1210096254 14400 # Node ID 7186e4f502d1b684f2b86b2451772a36a3b2e8f8 # Parent 05cfe011ca201921ba780e43df1fb2a62bfd34e9 bugfix in DataSet.minibatch to correctly wrap in all corner case. diff -r 05cfe011ca20 -r 7186e4f502d1 dataset.py --- a/dataset.py Tue May 06 10:53:28 2008 -0400 +++ b/dataset.py Tue May 06 13:50:54 2008 -0400 @@ -209,16 +209,18 @@ self.n_batches=n_batches self.n_batches_done=0 self.next_row=offset + self.offset=offset self.L=len(dataset) assert offset+minibatch_size<=self.L - ds_nbatches = (self.L-offset)/minibatch_size + ds_nbatches = (self.L-self.next_row)/self.minibatch_size if n_batches is not None: - ds_nbatches = max(n_batches,ds_nbatches) + ds_nbatches = min(n_batches,ds_nbatches) if fieldnames: assert dataset.hasFields(*fieldnames) else: - fieldnames=dataset.fieldNames() - self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) + self.fieldnames=dataset.fieldNames() + self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, + ds_nbatches,self.next_row) def __iter__(self): return self @@ -237,8 +239,8 @@ raise StopIteration # we must concatenate (vstack) the bottom and top parts of our minibatch # first get the beginning of our minibatch (top of dataset) - first_part = self.dataset.minibatches_nowrap(fieldnames,self.L-self.next_row,1,self.next_row).next() - second_part = self.dataset.minibatches_nowrap(fieldnames,upper-self.L,1,0).next() + first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() + second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() minibatch = Example(self.fieldnames, [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) for name in self.fieldnames]) @@ -246,6 +248,11 @@ self.n_batches_done+=1 if upper >= self.L and self.n_batches: self.next_row -= self.L + ds_nbatches = (self.L-self.next_row)/self.minibatch_size + if self.n_batches is not None: + ds_nbatches = min(self.n_batches,ds_nbatches) + self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, + ds_nbatches,self.next_row) return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack, self.dataset.valuesHStack), minibatch.keys())