diff dataset.py @ 98:7186e4f502d1

bugfix in DataSet.minibatch to correctly wrap in all corner case.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 13:50:54 -0400
parents 6fe972a7393c
children a8da709eb6a9
line wrap: on
line diff
--- a/dataset.py	Tue May 06 10:53:28 2008 -0400
+++ b/dataset.py	Tue May 06 13:50:54 2008 -0400
@@ -209,16 +209,18 @@
             self.n_batches=n_batches
             self.n_batches_done=0
             self.next_row=offset
+            self.offset=offset
             self.L=len(dataset)
             assert offset+minibatch_size<=self.L
-            ds_nbatches = (self.L-offset)/minibatch_size
+            ds_nbatches =  (self.L-self.next_row)/self.minibatch_size
             if n_batches is not None:
-                ds_nbatches = max(n_batches,ds_nbatches)
+                ds_nbatches = min(n_batches,ds_nbatches)
             if fieldnames:
                 assert dataset.hasFields(*fieldnames)
             else:
-                fieldnames=dataset.fieldNames()
-            self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset)
+                self.fieldnames=dataset.fieldNames()
+            self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
+                                                            ds_nbatches,self.next_row)
 
         def __iter__(self):
             return self
@@ -237,8 +239,8 @@
                     raise StopIteration
                 # we must concatenate (vstack) the bottom and top parts of our minibatch
                 # first get the beginning of our minibatch (top of dataset)
-                first_part = self.dataset.minibatches_nowrap(fieldnames,self.L-self.next_row,1,self.next_row).next()
-                second_part = self.dataset.minibatches_nowrap(fieldnames,upper-self.L,1,0).next()
+                first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
+                second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
                 minibatch = Example(self.fieldnames,
                                     [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
                                      for name in self.fieldnames])
@@ -246,6 +248,11 @@
             self.n_batches_done+=1
             if upper >= self.L and self.n_batches:
                 self.next_row -= self.L
+                ds_nbatches =  (self.L-self.next_row)/self.minibatch_size
+                if self.n_batches is not None:
+                    ds_nbatches = min(self.n_batches,ds_nbatches)
+                self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
+                                                                ds_nbatches,self.next_row)
             return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack,
                                                   self.dataset.valuesHStack),
                                  minibatch.keys())