diff dataset.py @ 101:a1740a99b81f

by default, in a minibatch without any fixed number of batchs, we need to finish at the end of the dataset. Now we return a minibatch at the end event if this minibacht size != the gived minibatch_size.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 16:01:53 -0400
parents a8da709eb6a9
children 8c0a1b11b007
line wrap: on
line diff
--- a/dataset.py	Tue May 06 14:01:22 2008 -0400
+++ b/dataset.py	Tue May 06 16:01:53 2008 -0400
@@ -231,19 +231,26 @@
         def next(self):
             if self.n_batches and self.n_batches_done==self.n_batches:
                 raise StopIteration
+            elif not self.n_batches and self.next_row ==self.L:
+                raise StopIteration
             upper = self.next_row+self.minibatch_size
             if upper <=self.L:
                 minibatch = self.iterator.next()
             else:
                 if not self.n_batches:
-                    raise StopIteration
-                # we must concatenate (vstack) the bottom and top parts of our minibatch
-                # first get the beginning of our minibatch (top of dataset)
-                first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
-                second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
-                minibatch = Example(self.fieldnames,
-                                    [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
-                                     for name in self.fieldnames])
+                    upper=min(upper, self.L)
+                    # if their is not a fixed number of batch, we continue to the end of the dataset.
+                    # this can create a minibatch that is smaller then the minibatch_size
+                    assert (self.L-self.next_row)<=self.minibatch_size
+                    minibatch = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
+                else:
+                    # we must concatenate (vstack) the bottom and top parts of our minibatch
+                    # first get the beginning of our minibatch (top of dataset)
+                    first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
+                    second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
+                    minibatch = Example(self.fieldnames,
+                                        [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
+                                         for name in self.fieldnames])
             self.next_row=upper
             self.n_batches_done+=1
             if upper >= self.L and self.n_batches: