Mercurial > pylearn
diff dataset.py @ 98:7186e4f502d1
bugfix in DataSet.minibatch to correctly wrap in all corner cases.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 13:50:54 -0400 |
parents | 6fe972a7393c |
children | a8da709eb6a9 |
line wrap: on
line diff
```diff
--- a/dataset.py	Tue May 06 10:53:28 2008 -0400
+++ b/dataset.py	Tue May 06 13:50:54 2008 -0400
@@ -209,16 +209,18 @@
             self.n_batches=n_batches
             self.n_batches_done=0
             self.next_row=offset
+            self.offset=offset
             self.L=len(dataset)
             assert offset+minibatch_size<=self.L
-            ds_nbatches = (self.L-offset)/minibatch_size
+            ds_nbatches = (self.L-self.next_row)/self.minibatch_size
             if n_batches is not None:
-                ds_nbatches = max(n_batches,ds_nbatches)
+                ds_nbatches = min(n_batches,ds_nbatches)
             if fieldnames:
                 assert dataset.hasFields(*fieldnames)
             else:
-                fieldnames=dataset.fieldNames()
-            self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset)
+                self.fieldnames=dataset.fieldNames()
+            self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
+                                                            ds_nbatches,self.next_row)
 
         def __iter__(self):
             return self
@@ -237,8 +239,8 @@
                     raise StopIteration
                 # we must concatenate (vstack) the bottom and top parts of our minibatch
                 # first get the beginning of our minibatch (top of dataset)
-                first_part = self.dataset.minibatches_nowrap(fieldnames,self.L-self.next_row,1,self.next_row).next()
-                second_part = self.dataset.minibatches_nowrap(fieldnames,upper-self.L,1,0).next()
+                first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
+                second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
                 minibatch = Example(self.fieldnames,
                                     [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
                                      for name in self.fieldnames])
@@ -246,6 +248,11 @@
             self.n_batches_done+=1
             if upper >= self.L and self.n_batches:
                 self.next_row -= self.L
+                ds_nbatches = (self.L-self.next_row)/self.minibatch_size
+                if self.n_batches is not None:
+                    ds_nbatches = min(self.n_batches,ds_nbatches)
+                self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
+                                                                ds_nbatches,self.next_row)
             return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack,
                                                   self.dataset.valuesHStack), minibatch.keys())
```