diff dataset.py @ 6:d5738b79089a

Removed MinibatchIterator and instead made minibatch_size a field of all DataSets, so that they can all iterate over minibatches, optionally.
author bengioy@bengiomac.local
date Mon, 24 Mar 2008 09:04:06 -0400
parents 8039918516fe
children 6f8f338686db
line wrap: on
line diff
--- a/dataset.py	Sun Mar 23 22:44:43 2008 -0400
+++ b/dataset.py	Mon Mar 24 09:04:06 2008 -0400
@@ -6,20 +6,31 @@
     A dataset is basically an iterator over examples. It does not necessarily
     have a fixed length (this is useful for 'streams' which feed on-line learning).
     Datasets with fixed and known length are FiniteDataSet, a subclass of DataSet.
-    Examples and datasets have named fields. 
+    Examples and datasets optionally have named fields. 
     One can obtain a sub-dataset by taking dataset.field or dataset(field1,field2,field3,...).
     Fields are not mutually exclusive, i.e. two fields can overlap in their actual content.
-    The content of a field can be of any type, but often will be a numpy tensor.
+    The content of a field can be of any type, but often will be a numpy array.
+    The minibatch_size field, if different than 1, means that the iterator (next() method)
+    returns not a single example but an array of length minibatch_size, i.e., an indexable
+    object.
     """
 
-    def __init__(self):
-        pass
+    def __init__(self,minibatch_size=1):
+        assert minibatch_size>0
+        self.minibatch_size=minibatch_size
 
     def __iter__(self):
         return self
 
     def next(self):
-        """Return the next example in the dataset."""
+        """
+        Return the next example or the next minibatch in the dataset.
+        A minibatch (of length > 1) should be something one can iterate on again in order
+        to obtain the individual examples. If the dataset has fields,
+        then the example or the minibatch must have the same fields
+        (typically this is implemented by returning another (small) dataset, when
+        there are fields).
+        """
         raise NotImplementedError
 
     def __getattr__(self,fieldname):
@@ -41,8 +52,8 @@
     and a subdataset can be obtained by slicing.
     """
 
-    def __init__(self):
-        pass
+    def __init__(self,minibatch_size):
+        DataSet.__init__(self,minibatch_size)
 
     def __len__(self):
         """len(dataset) returns the number of examples in the dataset."""
@@ -56,49 +67,27 @@
         """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
         raise NotImplementedError
 
-    def minibatches(self,minibatch_size):
-        """Return an iterator for the dataset that goes through minibatches of the given size."""
-        return MinibatchIterator(self,minibatch_size)
-
-class MinibatchIterator(object):
-    """
-    Iterator class for FiniteDataSet that can iterate by minibatches
-    (sub-dataset of consecutive examples).
-    """
-    def __init__(self,dataset,minibatch_size):
-        assert minibatch_size>0 and minibatch_size<len(dataset)
-        self.dataset=dataset
-        self.minibatch_size=minibatch_size
-        self.current=-minibatch_size
-    def __iter__(self):
-        return self
-    def next(self):
-        self.current+=self.minibatch_size
-        if self.current>=len(self.dataset):
-            self.current=-self.minibatchsize
-            raise StopIteration
-        return self.dataset[self.current:self.current+self.minibatchsize]
-    
 # we may want ArrayDataSet defined in another python file
 
 import numpy
 
 class ArrayDataSet(FiniteDataSet):
     """
-    A fixed-length and fixed-width dataset in which each element is a numpy.array
-    or a number, hence the whole dataset corresponds to a numpy.array. Fields
+    A fixed-length and fixed-width dataset in which each element is a numpy array
+    or a number, hence the whole dataset corresponds to a numpy array. Fields
     must correspond to a slice of columns. If the dataset has fields,
-    each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.
-    Any dataset can also be converted to a numpy.array (losing the notion of fields)
+    each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array.
+    Any dataset can also be converted to a numpy array (losing the notion of fields)
     by the asarray(dataset) call.
     """
 
-    def __init__(self,dataset=None,data=None,fields={}):
+    def __init__(self,dataset=None,data=None,fields={},minibatch_size=1):
         """
         Construct an ArrayDataSet, either from a DataSet, or from
-        a numpy.array plus an optional specification of fields (by
+        a numpy array plus an optional specification of fields (by
         a dictionary of column slices indexed by field names).
         """
+        FiniteDataSet.__init__(self,minibatch_size)
         self.current_row=-1 # used for view of this dataset as an iterator
         if dataset!=None:
             assert data==None and fields=={}
@@ -122,28 +111,40 @@
                     fieldslice = slice(start,fieldslice.stop,step)
                 # and coherent with the data array
                 assert fieldslice.start>=0 and fieldslice.stop<=self.width
+        assert minibatch_size<=len(self.data)
 
     def next(self):
         """
-        Return the next example in the dataset. If the dataset has fields,
-        the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.
+        Return the next example(s) in the dataset. If self.minibatch_size>1 return that
+        many examples. If the dataset has fields, the example or the minibatch of examples
+        is just a minibatch_size-rows ArrayDataSet (so that the fields can be accessed),
+        but that resulting mini-dataset has a minibatch_size of 1, so that one can iterate
+        example-wise on it. On the other hand, if the dataset has no fields (e.g. because
+        it is already the field of a bigger dataset), then the returned example or minibatch
+        is a numpy array. Following the array semantics of indexing and slicing,
+        if the minibatch_size is 1 (and there are no fields), then the result is an array
+        with one less dimension (e.g., a vector, if the dataset is a matrix), corresponding
+        to a row. Again, if the minibatch_size is >1, one can iterate on the result to
+        obtain individual examples (as rows).
         """
         if self.fields:
-            self.current_row+=1
-            if self.current_row==len(self.data):
-                self.current_row=-1
+            self.current_row+=self.minibatch_size
+            if self.current_row>=len(self.data):
+                self.current_row=-self.minibatch_size
                 raise StopIteration
-            return self[self.current_row]
+            if self.minibatch_size==1:
+                return self[self.current_row]
+            else:
+                return self[self.current_row:self.current_row+self.minibatch_size]
         else:
-            return self.data[self.current_row]
+            if self.minibatch_size==1:
+                return self.data[self.current_row]
+            else:
+                return self.data[self.current_row:self.current_row+self.minibatch_size]
 
     def __getattr__(self,fieldname):
-        """Return a sub-dataset containing only the given fieldname as field."""
-        data=self.data[self.fields[fieldname]]                
-        if len(data)==1:
-            return data
-        else:
-            return ArrayDataSet(data=data)
+        """Return a numpy array with the content associated with the given field name."""
+        return self.data[self.fields[fieldname]]
 
     def __call__(self,*fieldnames):
         """Return a sub-dataset containing only the given fieldnames as fields."""
@@ -155,7 +156,7 @@
         new_fields={}
         for field in self.fields:
             new_fields[field[0]]=slice(field[1].start-min_col,field[1].stop-min_col,field[1].step)
-        return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields)
+        return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields,minibatch_size=self.minibatch_size)
 
     def fieldNames(self):
         """Return the list of field names that are supported by getattr and getFields."""
@@ -179,7 +180,7 @@
 
     def __getslice__(self,*slice_args):
         """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
-        return ArrayDataSet(data=self.data[slice(slice_args)],fields=self.fields)
+        return ArrayDataSet(data=self.data[apply(slice,slice_args)],fields=self.fields)
 
     def asarray(self):
         if self.fields: