diff dataset.py @ 5:8039918516fe

Added MinibatchIterator
author bengioy@bengiomac.local
date Sun, 23 Mar 2008 22:44:43 -0400
parents f7dcfb5f9d5b
children d5738b79089a
line wrap: on
line diff
--- a/dataset.py	Sun Mar 23 22:14:10 2008 -0400
+++ b/dataset.py	Sun Mar 23 22:44:43 2008 -0400
@@ -56,6 +56,29 @@
         """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
         raise NotImplementedError
 
+    def minibatches(self,minibatch_size):
+        """Build a MinibatchIterator over this dataset that yields minibatches of minibatch_size examples."""
+        return MinibatchIterator(dataset=self,minibatch_size=minibatch_size)
+
+class MinibatchIterator(object):
+    """Iterator over a FiniteDataSet that yields minibatches: sub-datasets of
+    minibatch_size consecutive examples, obtained by slicing the dataset."""
+    def __init__(self,dataset,minibatch_size):
+        assert minibatch_size>0 and minibatch_size<len(dataset)
+        self.dataset=dataset
+        self.minibatch_size=minibatch_size
+        self.current=-minibatch_size # the first call to next() moves this to 0
+    def __iter__(self):
+        return self
+    def next(self):
+        """Return the next minibatch; at the end, rewind and raise StopIteration."""
+        self.current+=self.minibatch_size
+        if self.current>=len(self.dataset):
+            self.current=-self.minibatch_size # rewind so the iterator can be reused
+            raise StopIteration
+        return self.dataset[self.current:self.current+self.minibatch_size]
+    __next__=next # same method under the Python 3 iterator-protocol name
+    
 # we may want ArrayDataSet defined in another python file
 
 import numpy
@@ -197,4 +220,4 @@
                c+=slice_width
             return result
         return self.data
-        
+