changeset 272:6226ebafefc3

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 03 Jun 2008 16:13:42 -0400
parents 19b14afe04b7 (current diff) 38e7d90a1218 (diff)
children fa8abc813bd2
files dataset.py
diffstat 1 files changed, 7 insertions(+), 9 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Tue Jun 03 16:06:21 2008 -0400
+++ b/dataset.py	Tue Jun 03 16:13:42 2008 -0400
@@ -1055,32 +1055,30 @@
         return self.__dict__[key]
         
     def __iter__(self):
-        class ArrayDataSetIterator2(object):
-            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+        class ArrayDataSetIteratorIter(object):
+            def __init__(self,dataset,fieldnames):
                 if fieldnames is None: fieldnames = dataset.fieldNames()
                 # store the resulting minibatch in a lookup-list of values
                 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
                 self.dataset=dataset
-                self.minibatch_size=minibatch_size
-                assert offset>=0 and offset<len(dataset.data)
-                assert offset+minibatch_size<=len(dataset.data)
-                self.current=offset
+                self.current=0
                 self.columns = [self.dataset.fields_columns[f] 
                                 for f in self.minibatch._names]
+                self.l = self.dataset.data.shape[0]
             def __iter__(self):
                 return self
             def next(self):
                 #@todo: we suppose that we need to stop only when minibatch_size == 1.
                 # Otherwise, MinibatchWrapAroundIterator do it.
-                if self.current>=self.dataset.data.shape[0]:
+                if self.current>=self.l:
                     raise StopIteration
                 sub_data =  self.dataset.data[self.current]
                 self.minibatch._values = [sub_data[c] for c in self.columns]
 
-                self.current+=self.minibatch_size
+                self.current+=1
                 return self.minibatch
 
-        return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0)
+        return ArrayDataSetIteratorIter(self,self.fieldNames())
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ArrayDataSetIterator(object):