Mercurial > pylearn
diff dataset.py @ 101:a1740a99b81f
by default, in a minibatch without any fixed number of batchs, we need to finish at the end of the dataset. Now we return a minibatch at the end event if this minibacht size != the gived minibatch_size.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 16:01:53 -0400 |
parents | a8da709eb6a9 |
children | 8c0a1b11b007 |
line wrap: on
line diff
--- a/dataset.py Tue May 06 14:01:22 2008 -0400 +++ b/dataset.py Tue May 06 16:01:53 2008 -0400 @@ -231,19 +231,26 @@ def next(self): if self.n_batches and self.n_batches_done==self.n_batches: raise StopIteration + elif not self.n_batches and self.next_row ==self.L: + raise StopIteration upper = self.next_row+self.minibatch_size if upper <=self.L: minibatch = self.iterator.next() else: if not self.n_batches: - raise StopIteration - # we must concatenate (vstack) the bottom and top parts of our minibatch - # first get the beginning of our minibatch (top of dataset) - first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() - second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() - minibatch = Example(self.fieldnames, - [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) - for name in self.fieldnames]) + upper=min(upper, self.L) + # if their is not a fixed number of batch, we continue to the end of the dataset. + # this can create a minibatch that is smaller then the minibatch_size + assert (self.L-self.next_row)<=self.minibatch_size + minibatch = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() + else: + # we must concatenate (vstack) the bottom and top parts of our minibatch + # first get the beginning of our minibatch (top of dataset) + first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() + second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() + minibatch = Example(self.fieldnames, + [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) + for name in self.fieldnames]) self.next_row=upper self.n_batches_done+=1 if upper >= self.L and self.n_batches: