diff dataset.py @ 7:6f8f338686db

Moved iterating counter into a FiniteDataSetIterator to allow embedded iterations and multiple threads iterating at the same time on a dataset.
author bengioy@bengiomac.local
date Mon, 24 Mar 2008 13:20:15 -0400
parents d5738b79089a
children d1c394486037
line wrap: on
line diff
--- a/dataset.py	Mon Mar 24 09:04:06 2008 -0400
+++ b/dataset.py	Mon Mar 24 13:20:15 2008 -0400
@@ -20,14 +20,11 @@
         self.minibatch_size=minibatch_size
 
     def __iter__(self):
-        return self
-
-    def next(self):
         """
-        Return the next example or the next minibatch in the dataset.
-        A minibatch (of length > 1) should be something one can iterate on again in order
-        to obtain the individual examples. If the dataset has fields,
-        then the example or the minibatch must have the same fields
+        Return an iterator, whose next() method returns the next example or the next 
+        minibatch in the dataset. A minibatch (of length > 1) should be something one 
+        can iterate on again in order to obtain the individual examples. If the dataset 
+        has fields, then the example or the minibatch must have the same fields
         (typically this is implemented by returning another (small) dataset, when
         there are fields).
         """
@@ -55,6 +52,9 @@
     def __init__(self,minibatch_size):
         DataSet.__init__(self,minibatch_size)
 
+    def __iter__(self):
+        return FiniteDataSetIterator(self)
+    
     def __len__(self):
         """len(dataset) returns the number of examples in the dataset."""
         raise NotImplementedError
@@ -67,6 +67,35 @@
         """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
         raise NotImplementedError
 
+class FiniteDataSetIterator(object):
+    def __init__(self,dataset):
+        self.dataset=dataset
+        self.current = -self.dataset.minibatch_size
+        
+    def next(self):
+        """
+        Return the next example(s) in the dataset. If self.dataset.minibatch_size>1 return that
+        many examples. If the dataset has fields, the example or the minibatch of examples
+        is just a minibatch_size-rows ArrayDataSet (so that the fields can be accessed),
+        but that resulting mini-dataset has a minibatch_size of 1, so that one can iterate
+        example-wise on it. On the other hand, if the dataset has no fields (e.g. because
+        it is already the field of a bigger dataset), then the returned example or minibatch
+        may be any indexable object, such as a numpy array. Following the array semantics of indexing
+        and slicing, if the minibatch_size is 1 (and there are no fields), then the result is an array
+        with one less dimension (e.g., a vector, if the dataset is a matrix), corresponding
+        to a row. Again, if the minibatch_size is >1, one can iterate on the result to
+        obtain individual examples (as rows).
+        """
+        self.current+=self.dataset.minibatch_size
+        if self.current>=len(self.dataset):
+            self.current=-self.dataset.minibatch_size
+            raise StopIteration
+        if self.dataset.minibatch_size==1:
+            return self.dataset[self.current]
+        else:
+            return self.dataset[self.current:self.current+self.dataset.minibatch_size]
+
+
 # we may want ArrayDataSet defined in another python file
 
 import numpy
@@ -88,7 +117,6 @@
         a dictionary of column slices indexed by field names).
         """
         FiniteDataSet.__init__(self,minibatch_size)
-        self.current_row=-1 # used for view of this dataset as an iterator
         if dataset!=None:
             assert data==None and fields=={}
             # convert dataset to an ArrayDataSet
@@ -108,43 +136,20 @@
                 if not step:
                     step=1
                 if not fieldslice.start or not fieldslice.step:
-                    fieldslice = slice(start,fieldslice.stop,step)
+                    fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step)
                 # and coherent with the data array
                 assert fieldslice.start>=0 and fieldslice.stop<=self.width
         assert minibatch_size<=len(self.data)
 
-    def next(self):
-        """
-        Return the next example(s) in the dataset. If self.minibatch_size>1 return that
-        many examples. If the dataset has fields, the example or the minibatch of examples
-        is just a minibatch_size-rows ArrayDataSet (so that the fields can be accessed),
-        but that resulting mini-dataset has a minibatch_size of 1, so that one can iterate
-        example-wise on it. On the other hand, if the dataset has no fields (e.g. because
-        it is already the field of a bigger dataset), then the returned example or minibatch
-        is a numpy array. Following the array semantics of indexing and slicing,
-        if the minibatch_size is 1 (and there are no fields), then the result is an array
-        with one less dimension (e.g., a vector, if the dataset is a matrix), corresponding
-        to a row. Again, if the minibatch_size is >1, one can iterate on the result to
-        obtain individual examples (as rows).
+    def __getattr__(self,fieldname):
         """
-        if self.fields:
-            self.current_row+=self.minibatch_size
-            if self.current_row>=len(self.data):
-                self.current_row=-self.minibatch_size
-                raise StopIteration
-            if self.minibatch_size==1:
-                return self[self.current_row]
-            else:
-                return self[self.current_row:self.current_row+self.minibatch_size]
-        else:
-            if self.minibatch_size==1:
-                return self.data[self.current_row]
-            else:
-                return self.data[self.current_row:self.current_row+self.minibatch_size]
-
-    def __getattr__(self,fieldname):
-        """Return a numpy array with the content associated with the given field name."""
-        return self.data[self.fields[fieldname]]
+        Return a numpy array with the content associated with the given field name.
+        If this is a one-example dataset, then a row, i.e., numpy array (of one less dimension
+        than the dataset.data) is returned.
+        """
+        if len(self.data)==1:
+            return self.data[0,self.fields[fieldname]]
+        return self.data[:,self.fields[fieldname]]
 
     def __call__(self,*fieldnames):
         """Return a sub-dataset containing only the given fieldnames as fields."""
@@ -176,49 +181,51 @@
                 return ArrayDataSet(data=data[slice],fields=self.fields)
             return ArrayDataSet(data=self.data[i:i+1],fields=self.fields)
         else:
-            return data[i]
+            return self.data[i]
 
     def __getslice__(self,*slice_args):
         """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
         return ArrayDataSet(data=self.data[apply(slice,slice_args)],fields=self.fields)
 
     def asarray(self):
-        if self.fields:
-            columns_used = numpy.zeros((self.data.shape[1]),dtype=bool)
-            for field_slice in self.fields.values():
-                for c in xrange(field_slice.start,field_slice.stop,field_slice.step):
-                    columns_used[c]=True
-            # try to figure out if we can map all the slices into one slice:
-            mappable_to_one_slice = True
-            start=0
-            while start<len(columns_used) and not columns_used[start]:
-                start+=1
-            stop=len(columns_used)
-            while stop>0 and not columns_used[stop-1]:
-                stop-=1
-            step=0
-            i=start
-            while i<stop:
-                j=i+1
-                while not columns_used[j] and j<stop:
-                    j+=1
-                if step:
-                    if step!=j-i:
-                        mappable_to_one_slice = False
-                        break
-                else:
-                    step = j-i
-            if mappable_to_one_slice:
-                return data[slice(start,stop,step)]
-            # else make contiguous copy
-            n_columns = sum(columns_used)
-            result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype)
-            c=0
-            for field_slice in self.fields.values():
-               slice_width=field_slice.stop-field_slice.start/field_slice.step
-               # copy the field here
-               result[:,slice(c,slice_width)]=self.data[field_slice]
-               c+=slice_width
-            return result
-        return self.data
-
+        if not self.fields:
+            return self.data
+        # else, select subsets of columns mapped by the fields
+        columns_used = numpy.zeros((self.data.shape[1]),dtype=bool)
+        for field_slice in self.fields.values():
+            for c in xrange(field_slice.start,field_slice.stop,field_slice.step):
+                columns_used[c]=True
+        # try to figure out if we can map all the slices into one slice:
+        mappable_to_one_slice = True
+        start=0
+        while start<len(columns_used) and not columns_used[start]:
+            start+=1
+        stop=len(columns_used)
+        while stop>0 and not columns_used[stop-1]:
+            stop-=1
+        step=0
+        i=start
+        while i<stop:
+            j=i+1
+            while j<stop and not columns_used[j]:
+                j+=1
+            if step:
+                if step!=j-i:
+                    mappable_to_one_slice = False
+                    break
+            else:
+                step = j-i
+            i=j
+        if mappable_to_one_slice:
+            return self.data[:,slice(start,stop,step)]
+        # else make contiguous copy
+        n_columns = sum(columns_used)
+        result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype)
+        print result.shape
+        c=0
+        for field_slice in self.fields.values():
+            slice_width=field_slice.stop-field_slice.start/field_slice.step
+            # copy the field here
+            result[:,slice(c,slice_width)]=self.data[:,field_slice]
+            c+=slice_width
+        return result