diff dataset.py @ 231:38beb81f4e8b

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 27 May 2008 13:46:03 -0400
parents 17c5d080964b 6f55e301c687
children a70f2c973ea5 ddb88a8e9fd2
--- a/dataset.py	Tue May 27 13:23:05 2008 -0400
+++ b/dataset.py	Tue May 27 13:46:03 2008 -0400
@@ -1043,7 +1043,31 @@
         assert key in self.__dict__ # else it means we are trying to access a non-existing property
         return self.__dict__[key]
         
-            
+    def __iter__(self):
+        class ArrayDataSetIterator2(object):
+            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): # n_batches is accepted but unused here
+                if fieldnames is None: fieldnames = dataset.fieldNames()
+                # store the resulting minibatch in a lookup-list of values
+                self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
+                self.dataset=dataset
+                self.minibatch_size=minibatch_size
+                assert offset>=0 and offset<len(dataset.data)
+                assert offset+minibatch_size<=len(dataset.data)
+                self.current=offset
+            def __iter__(self):
+                return self
+            def next(self):
+                #@todo: we only need this stop condition when minibatch_size == 1;
+                # otherwise MinibatchWrapAroundIterator handles stopping.
+                if self.current>=self.dataset.data.shape[0]:
+                    raise StopIteration
+                sub_data =  self.dataset.data[self.current]
+                self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names]
+                self.current+=self.minibatch_size
+                return self.minibatch
+
+        return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0) # all fields, minibatch_size=1, n_batches=0, offset=0
+
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ArrayDataSetIterator(object):
             def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
@@ -1058,6 +1082,7 @@
             def __iter__(self):
                 return self
             def next(self):
+                #@todo: we assume that MinibatchWrapAroundIterator stops this iterator
                 sub_data =  self.dataset.data[self.current:self.current+self.minibatch_size]
                 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names]
                 self.current+=self.minibatch_size
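For context, the first hunk makes ArrayDataSet instances directly iterable: iterating the dataset yields one example at a time (a LookupList mapping each field name to that row's columns), while minibatches_nowrap keeps yielding row ranges. The sketch below is a minimal, self-contained rendition of that pattern, not code from the changeset; TinyArrayDataSet and its fields are hypothetical stand-ins for pylearn's ArrayDataSet and its fields_columns mapping.

import numpy

class TinyArrayDataSet(object):
    # Hypothetical stand-in for pylearn's ArrayDataSet, only to illustrate
    # the iteration pattern added in this changeset.
    def __init__(self, data, fields_columns):
        self.data = data                      # 2-D array, one example per row
        self.fields_columns = fields_columns  # field name -> column index or slice

    def __iter__(self):
        # Example-wise iteration, as ArrayDataSetIterator2 does with
        # minibatch_size=1: map every field onto its columns of the current row.
        for row in self.data:
            yield dict((name, row[cols])
                       for name, cols in self.fields_columns.items())

data = numpy.arange(12.).reshape(4, 3)           # 4 examples, 3 columns
ds = TinyArrayDataSet(data, {'input': slice(0, 2), 'target': 2})
for example in ds:
    print(example['input'], example['target'])   # e.g. [0. 1.] 2.0

Note that the real iterator reuses a single LookupList across calls to next() instead of building a fresh container per example; the dict above trades that micro-optimization for clarity.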