diff dataset.py @ 28:541a273bc89f

Removed the __array__ method from dataset: its semantics were unclear (because of the possibility of overlapping fields).
author bengioy@grenat.iro.umontreal.ca
date Fri, 11 Apr 2008 13:08:51 -0400
parents 672fe4b23032
children 46c5c90019c2
--- a/dataset.py	Fri Apr 11 11:16:09 2008 -0400
+++ b/dataset.py	Fri Apr 11 13:08:51 2008 -0400
@@ -37,8 +37,8 @@
 
     Datasets of finite length should be sub-classes of FiniteLengthDataSet.
 
-    Datasets whose elements can be indexed and sub-datasets of consecutive
-    examples (i.e. slices) can be extracted from should be sub-classes of
+    Datasets whose elements can be indexed, and from which sub-datasets (given by
+    a subset of the example indices) can be extracted, should be sub-classes of
     SliceableDataSet.
 
     Datasets with a finite number of fields should be sub-classes of
@@ -230,8 +230,10 @@
     Virtual interface, a subclass of DataSet for datasets which are sliceable
     and whose individual elements can be accessed, generally respecting the
     python semantics for [spec], where spec is either a non-negative integer
-    (for selecting one example), or a python slice (for selecting a sub-dataset
-    comprising the specified examples). This is useful for obtaining
+    (for selecting one example), a python slice object slice(start,stop,step) (for
+    selecting the regular sub-dataset of examples start,start+step,start+2*step,...,
+    up to but excluding stop), or a sequence (e.g. a list) of integers
+    [i1,i2,...,in] (for selecting an arbitrary subset of examples). This is useful for obtaining
     sub-datasets, e.g. for splitting a dataset into training and test sets.
     """
     def __init__(self):
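
For concreteness, a minimal usage sketch of the [spec] forms described above (the
import path, and the use of ArrayDataSet without fields, are assumptions for
illustration, not part of the changeset):

    import numpy
    from dataset import ArrayDataSet  # assuming this module is on the path

    d = ArrayDataSet(numpy.random.rand(100, 5))  # 100 examples, no fields
    one    = d[3]          # the 4th example
    first  = d[0:10]       # sub-dataset with examples 0,1,...,9
    evens  = d[0:10:2]     # sub-dataset with examples 0,2,4,6,8
    picked = d[[3, 1, 4]]  # sub-dataset with examples 3, 1 and 4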
@@ -250,11 +252,19 @@
         return DataSet.Iterator(self, fieldnames, minibatch_size, n_batches)
 
     def __getitem__(self,i):
-        """dataset[i] returns the (i+1)-th example of the dataset."""
+        """
+        dataset[i] returns the (i+1)-th example of the dataset.
+        dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.
+        dataset[i:j:s] returns the subdataset with examples i,i+s,i+2*s,... (up to but excluding j).
+        dataset[[i1,i2,...,in]] returns the subdataset with examples i1,i2,...,in.
+        """
         raise AbstractFunction()
 
     def __getslice__(self,*slice_args):
-        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
+        """
+        dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.
+        dataset[i:j:s] returns the subdataset with examples i,i+s,i+2*s,... (up to but excluding j).
+        """
         raise AbstractFunction()
 
 
@@ -348,7 +358,8 @@
     It is a  fixed-length and fixed-width dataset 
     in which each element is a fixed dimension numpy array or a number, hence the whole 
     dataset corresponds to a numpy array. Fields
-    must correspond to a slice of array columns. If the dataset has fields,
+    must correspond to a slice of array columns or to a list of column numbers.
+    If the dataset has fields,
     each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array.
     Any dataset can also be converted to a numpy array (losing the notion of fields)
     by the numpy.array(dataset) call.
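
A sketch of how fields map onto the data array under the new convention (the
LookupList import location is an assumption):

    import numpy
    from dataset import ArrayDataSet    # assumed import
    from lookup_list import LookupList  # assumed location of LookupList

    data = numpy.random.rand(100, 5)
    # 'input' is a slice of array columns; 'target' is a list of column numbers
    fields = LookupList(['input', 'target'], [slice(0, 4), [4]])
    d = ArrayDataSet(data, fields=fields)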
@@ -396,7 +407,7 @@
             if self.next_count == self.next_max:
                 raise StopIteration
 
-            #determine the first and last elements of the slice we'll return
+            #determine the first and last elements of the minibatch slice we'll return
             n_rows = self.dataset.data.shape[0]
             self.current = self.next_index()
             upper = self.current + self.minibatch_size
@@ -423,7 +434,7 @@
         There are two ways to construct an ArrayDataSet: (1) from an
         existing dataset (which may result in a copy of the data in a numpy array),
         or (2) from a numpy.array (the data argument), along with an optional description
-        of the fields (a LookupList of column slices indexed by field names).
+        of the fields (a LookupList, indexed by field names, of column slices or lists of column numbers).
         """
         self.data=data
         self.fields=fields
@@ -431,17 +442,22 @@
 
         if fields:
             for fieldname,fieldslice in fields.items():
-                # make sure fieldslice.start and fieldslice.step are defined
-                start=fieldslice.start
-                step=fieldslice.step
-                if not start:
-                    start=0
-                if not step:
-                    step=1
-                if not fieldslice.start or not fieldslice.step:
-                    fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step)
-                # and coherent with the data array
-                assert fieldslice.start >= 0 and fieldslice.stop <= cols
+                assert type(fieldslice) is int or isinstance(fieldslice,slice) or hasattr(fieldslice,"__iter__")
+                if hasattr(fieldslice,"__iter__"): # is a sequence
+                    for i in fieldslice:
+                        assert type(i) is int
+                elif isinstance(fieldslice,slice):
+                    # make sure fieldslice.start and fieldslice.step are defined
+                    start=fieldslice.start
+                    step=fieldslice.step
+                    if not start:
+                        start=0
+                    if not step:
+                        step=1
+                    if not fieldslice.start or not fieldslice.step:
+                        fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step)
+                    # and coherent with the data array
+                    assert fieldslice.start >= 0 and fieldslice.stop <= cols
 
     def minibatches(self,
             fieldnames = DataSet.minibatches_fieldnames,
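
The slice branch of the constructor above fills in the defaults of a partially
specified slice; a standalone mirror of that normalization (a sketch with a
hypothetical helper name, not the module's API):

    def normalize_field_slice(fieldslice):
        # fill in the missing start/step, as the constructor does
        start = fieldslice.start or 0
        step = fieldslice.step or 1
        return slice(start, fieldslice.stop, step)

    assert normalize_field_slice(slice(3)) == slice(0, 3, 1)
    assert normalize_field_slice(slice(1, 4)) == slice(1, 4, 1)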
@@ -469,15 +485,7 @@
 
     def __call__(self,*fieldnames):
         """Return a sub-dataset containing only the given fieldnames as fields."""
-        min_col=self.data.shape[1]
-        max_col=0
-        for field_slice in self.fields.values():
-            min_col=min(min_col,field_slice.start)
-            max_col=max(max_col,field_slice.stop)
-        new_fields=LookupList()
-        for fieldname,fieldslice in self.fields.items():
-            new_fields[fieldname]=slice(fieldslice.start-min_col,fieldslice.stop-min_col,fieldslice.step)
-        return ArrayDataSet(self.data[:,min_col:max_col],fields=new_fields)
+        return ArrayDataSet(self.data,fields=LookupList(fieldnames,[self.fields[fieldname] for fieldname in fieldnames]))
 
     def fieldNames(self):
         """Return the list of field names that are supported by getattr and hasField."""
@@ -489,8 +497,11 @@
     
     def __getitem__(self,i):
         """
-        dataset[i] returns the (i+1)-th Example of the dataset. If there are no fields
-        the result is just a numpy array (for the i-th row of the dataset data matrix).
+        dataset[i] returns the (i+1)-th Example of the dataset.
+        If there are no fields the result is just a numpy array (for the i-th row of the dataset data matrix).
+        dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.
+        dataset[i:j:s] returns the subdataset with examples i,i+s,i+2*s,... (up to but excluding j).
+        dataset[[i1,i2,...,in]] returns the subdataset with examples i1,i2,...,in.
         """
         if self.fields:
             fieldnames,fieldslices=zip(*self.fields.items())
@@ -499,36 +510,37 @@
             return self.data[i]
 
     def __getslice__(self,*args):
-        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
+        """
+        dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.
+        dataset[i:j:s] returns the subdataset with examples i,i+s,i+2*s,... (up to but excluding j).
+        """
         return ArrayDataSet(self.data.__getslice__(*args), fields=self.fields)
 
-    def __array__(self):
-        """Return a view of this dataset which is an numpy.ndarray (i.e. losing
-        the identity and name of fields within the dataset).
-
-        Numpy uses this special function name to retrieve an ndarray view for
-        function such as numpy.sum, numpy.dot, numpy.asarray, etc.
-
-        If this dataset has no fields, then we simply return self.data,
-        otherwise things are complicated. 
-        - why do we want this behaviour when there are fields? (JB)
-        - for convenience and completeness (but maybe it would make
-          more sense to implement this through a 'field-merging'
-          dataset). (YB)
+    def indices_of_unique_columns_used(self):
+        """
+        Return the sorted unique indices of the columns actually used by the fields,
+        along with a boolean that is True if some of the used columns overlap
+        (are shared by several fields). Overlapping indices appear only once in the result.
         """
-        if not self.fields:
-            return self.data
-        # else, select subsets of columns mapped by the fields
         columns_used = numpy.zeros((self.data.shape[1]),dtype=bool)
-        overlapping_fields = False
-        n_columns = 0
+        overlapping_columns = False
         for field_slice in self.fields.values():
-            for c in xrange(field_slice.start,field_slice.stop,field_slice.step):
-                n_columns += 1
-                if columns_used[c]: overlapping_fields=True
-                columns_used[c]=True
-        # try to figure out if we can map all the slices into one slice:
-        mappable_to_one_slice = not overlapping_fields
+            if numpy.sum(columns_used[field_slice])>0: overlapping_columns=True
+            columns_used[field_slice]=True
+        return [i for i,used in enumerate(columns_used) if used],overlapping_columns
+
+    def slice_of_unique_columns_used(self):
+        """
+        Return None if the indices_of_unique_columns_used do not form a slice.
+        If they do, return that slice: it means the used columns can be extracted
+        from the data array without making a copy. Even if the fields overlap,
+        if their unique columns form a slice, that slice is returned.
+        """
+        indices,overlapping_columns = self.indices_of_unique_columns_used()
+        # rebuild the boolean column mask expected by the loop below
+        columns_used = numpy.zeros((self.data.shape[1]),dtype=bool)
+        columns_used[indices] = True
+        mappable_to_one_slice = len(indices)>0
-        if not overlapping_fields:
+        if mappable_to_one_slice: # overlap is fine: the mask holds unique columns
             start=0
             while start<len(columns_used) and not columns_used[start]:
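
The intent of slice_of_unique_columns_used can be captured by a small standalone
function (a sketch of the algorithm with a hypothetical name, not the module's
code): map sorted unique column indices to an equivalent slice when one exists.

    def indices_to_slice(indices):
        # the slice covering exactly these sorted unique indices, or None
        if not indices:
            return None
        step = indices[1] - indices[0] if len(indices) > 1 else 1
        for i, j in zip(indices[:-1], indices[1:]):
            if j - i != step:
                return None  # irregular spacing: not expressible as one slice
        return slice(indices[0], indices[-1] + 1, step)

    assert indices_to_slice([2, 3, 4]) == slice(2, 5, 1)
    assert indices_to_slice([1, 3, 5]) == slice(1, 6, 2)
    assert indices_to_slice([0, 1, 4]) is None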
@@ -549,18 +558,8 @@
                 else:
                     step = j-i
                 i=j
-        if mappable_to_one_slice:
-            return self.data[:,slice(start,stop,step)]
-        # else make contiguous copy (copying the overlapping columns)
-        result = numpy.zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype)
-        c=0
-        for field_slice in self.fields.values():
-            slice_width=(field_slice.stop-field_slice.start)/field_slice.step
-            # copy the field here
-            result[:,slice(c,c+slice_width)]=self.data[:,field_slice]
-            c+=slice_width
-        return result
-
+        return slice(start,stop,step) if mappable_to_one_slice else None
+
 class ApplyFunctionDataSet(DataSet):
     """
     A dataset that contains as fields the results of applying