# HG changeset patch # User Frederic Bastien # Date 1210104113 14400 # Node ID a1740a99b81fd955098fa500bf14266ab1ec1060 # Parent 574f4db76022914217f6d8b70322a2ca1bafc53e By default, in a minibatch iterator without any fixed number of batches, we need to finish at the end of the dataset. Now we return a minibatch at the end even if this minibatch's size != the given minibatch_size. diff -r 574f4db76022 -r a1740a99b81f dataset.py --- a/dataset.py Tue May 06 14:01:22 2008 -0400 +++ b/dataset.py Tue May 06 16:01:53 2008 -0400 @@ -231,19 +231,26 @@ def next(self): if self.n_batches and self.n_batches_done==self.n_batches: raise StopIteration + elif not self.n_batches and self.next_row ==self.L: + raise StopIteration upper = self.next_row+self.minibatch_size if upper <=self.L: minibatch = self.iterator.next() else: if not self.n_batches: - raise StopIteration - # we must concatenate (vstack) the bottom and top parts of our minibatch - # first get the beginning of our minibatch (top of dataset) - first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() - second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() - minibatch = Example(self.fieldnames, - [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) - for name in self.fieldnames]) + upper=min(upper, self.L) + # if their is not a fixed number of batch, we continue to the end of the dataset. 
+ # this can create a minibatch that is smaller then the minibatch_size + assert (self.L-self.next_row)<=self.minibatch_size + minibatch = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() + else: + # we must concatenate (vstack) the bottom and top parts of our minibatch + # first get the beginning of our minibatch (top of dataset) + first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() + second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() + minibatch = Example(self.fieldnames, + [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) + for name in self.fieldnames]) self.next_row=upper self.n_batches_done+=1 if upper >= self.L and self.n_batches: