Mercurial > pylearn
diff dataset.py @ 5:8039918516fe
Added MinibatchIterator
author | bengioy@bengiomac.local |
---|---|
date | Sun, 23 Mar 2008 22:44:43 -0400 |
parents | f7dcfb5f9d5b |
children | d5738b79089a |
line wrap: on
line diff
--- a/dataset.py Sun Mar 23 22:14:10 2008 -0400 +++ b/dataset.py Sun Mar 23 22:44:43 2008 -0400 @@ -56,6 +56,29 @@ """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" raise NotImplementedError + def minibatches(self,minibatch_size): + """Return an iterator for the dataset that goes through minibatches of the given size.""" + return MinibatchIterator(self,minibatch_size) + +class MinibatchIterator(object): + """ + Iterator class for FiniteDataSet that can iterate by minibatches + (sub-dataset of consecutive examples). + """ + def __init__(self,dataset,minibatch_size): + assert minibatch_size>0 and minibatch_size<len(dataset) + self.dataset=dataset + self.minibatch_size=minibatch_size + self.current=-minibatch_size + def __iter__(self): + return self + def next(self): + self.current+=self.minibatch_size + if self.current>=len(self.dataset): + self.current=-self.minibatchsize + raise StopIteration + return self.dataset[self.current:self.current+self.minibatchsize] + # we may want ArrayDataSet defined in another python file import numpy @@ -197,4 +220,4 @@ c+=slice_width return result return self.data - +