changeset 41:283e95c15b47

Added ArrayDataSet
author bengioy@grenat.iro.umontreal.ca
date Fri, 25 Apr 2008 12:04:55 -0400
parents 88fd1cce08b9
children 9b68774fcc6b
files dataset.py
diffstat 1 files changed, 125 insertions(+), 23 deletions(-) [+]
--- a/dataset.py	Fri Apr 25 10:41:19 2008 -0400
+++ b/dataset.py	Fri Apr 25 12:04:55 2008 -0400
@@ -3,6 +3,7 @@
 Example = LookupList
 from misc import *
 import copy
+import string
+import numpy
 
 class AbstractFunction (Exception): """Derived class must override this function"""
 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented"""
@@ -75,6 +76,13 @@
 
     * dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in.
 
+    * dataset['key'] returns a property associated with the given 'key' string.
+      If 'key' is a fieldname, then the VStacked field values (an iterable over
+      the values of that field) are returned. Other keys may be supported
+      by different dataset subclasses. The following key names should be supported:
+          - 'description': a textual description or name for the dataset
+          - '<fieldname>.type': a type name or value for a given <fieldname>
+
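+      For example, assuming a dataset with a field named 'input' (an
+      illustrative name, not guaranteed by every subclass):
+          inputs = dataset['input']      # iterable over all 'input' values
+          print dataset['description']   # textual description of the dataset
+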
     Datasets can be concatenated either vertically (increasing the length) or
     horizontally (augmenting the set of fields), if they are compatible, using
     the following operations (with the same basic semantics as numpy.hstack
@@ -96,6 +104,11 @@
     a DataSetFields fields1 and fields2, and fields1 | fields2 concatenates their
     examples.
 
+    A dataset can hold arbitrary key-value pairs that give access to meta-data
+    or other properties of the dataset, or to the result of a computation
+    stored with the dataset. These can be accessed through the [key] syntax
+    when key is a string (more specifically, neither an integer, a slice, nor a list).
+
     A DataSet sub-class should always redefine the following methods:
       * __len__ if it is not a stream
       * fieldNames
@@ -108,8 +121,12 @@
       * __iter__
     """
 
-    def __init__(self):
-        pass
+    def __init__(self,description=None,field_types=None):
+        if description is None:
+            # by default use "<DataSetType>(<SuperClass1>,<SuperClass2>,...)"
+            description = type(self).__name__ + "(" + string.join([x.__name__ for x in type(self).__bases__],",") + ")"
+        self.description=description
+        self.field_types=field_types
     
     class MinibatchToSingleExampleIterator(object):
         """
@@ -320,6 +337,12 @@
         dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.
         dataset[i:j:s] returns the subdataset with examples i,i+2,i+4...,j-2.
         dataset[[i1,i2,..,in]] returns the subdataset with examples i1,i2,...,in.
+        dataset['key'] returns a property associated with the given 'key' string.
+        If 'key' is a fieldname, then the VStacked field values (an iterable over
+        the values of that field) are returned. Other keys may be supported
+        by different dataset subclasses. The following key names are encouraged:
+          - 'description': a textual description or name for the dataset
+          - '<fieldname>.type': a type name or value for a given <fieldname>
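+
+        For example (a sketch; 'input' is a hypothetical fieldname):
+            x = dataset[3]             # the fourth example
+            sub = dataset[0:100:2]     # every other example among the first 100
+            inputs = dataset['input']  # all values of the 'input' field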
 
         Note that some stream datasets may be unable to implement random access, i.e.
         arbitrary slicing/indexing
@@ -331,23 +354,33 @@
         always be the most efficient way to obtain the result, especially if
         the data are actually stored in a memory array.
         """
+        # check for an index
         if type(i) is int:
             return DataSet.MinibatchToSingleExampleIterator(
                 self.minibatches(minibatch_size=1,n_batches=1,offset=i)).next()
+        rows=None
+        # or a slice
         if type(i) is slice:
             # slice attributes are read-only, so work with local copies
             start = i.start or 0
             step = i.step or 1
             if step == 1:
                 return self.minibatches(minibatch_size=i.stop-start,n_batches=1,offset=start).next().examples()
             rows = range(start,i.stop,step)
-        else:
-            assert type(i) is list
+        # or a list of indices
+        elif type(i) is list:
             rows = i
-        fields_values = zip(*[self[row] for row in rows])
-        return MinibatchDataSet(
-            Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
-                                        for fieldname,field_values
-                                        in zip(self.fieldNames(),fields_values)]))
+        if rows is not None:
+            fields_values = zip(*[self[row] for row in rows])
+            return MinibatchDataSet(
+                Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
+                                            for fieldname,field_values
+                                            in zip(self.fieldNames(),fields_values)]))
+        # else check for a fieldname
+        if self.hasFields(i):
+            return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0]
+        # else we are trying to access a property of the dataset
+        assert i in self.__dict__, "not a field or a property of this dataset: "+str(i)
+        return self.__dict__[i]
 
     def valuesHStack(self,fieldnames,fieldvalues):
         """
@@ -461,7 +494,7 @@
         assert dataset.hasFields(*fieldnames)
         LookupList.__init__(self,dataset.fieldNames(),
                             dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(),
-                                                minibatch_size=len(dataset)).next()
+                                                minibatch_size=len(dataset)).next())
     def examples(self):
         return self.dataset
     
@@ -522,7 +555,7 @@
                 self.ds=ds
                 self.next_example=offset
                 assert minibatch_size > 0
-                if offset+minibatch_size > ds.length
+                if offset+minibatch_size > ds.length:
                     raise NotImplementedError()
             def __iter__(self):
                 return self
@@ -554,8 +587,8 @@
 
     TODO: automatically detect a chain of stacked datasets due to A | B | C | D ...
     """
-    def __init__(self,datasets,accept_nonunique_names=False):
-        DataSet.__init__(self)
+    def __init__(self,datasets,accept_nonunique_names=False,description=None,field_types=None):
+        DataSet.__init__(self,description,field_types)
         self.datasets=datasets
         self.accept_nonunique_names=accept_nonunique_names
         self.fieldname2dataset={}
@@ -596,11 +629,7 @@
     def fieldNames(self):
         return self.fieldname2dataset.keys()
             
-    def minibatches_nowrap(self,
-                           fieldnames = minibatches_fieldnames,
-                           minibatch_size = minibatches_minibatch_size,
-                           n_batches = minibatches_n_batches,
-                           offset = 0):
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
 
         class Iterator(object):
             def __init__(self,hsds,iterators):
@@ -699,11 +728,7 @@
         row_within_dataset = self.datasets_start_row[dataset_index]
         return dataset_index, row_within_dataset
         
-    def minibatches_nowrap(self,
-                           fieldnames = minibatches_fieldnames,
-                           minibatch_size = minibatches_minibatch_size,
-                           n_batches = minibatches_n_batches,
-                           offset = 0):
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
             
         class Iterator(object):
             def __init__(self,vsds):
@@ -762,7 +787,84 @@
                     self.move_to_next_dataset()
                 return 
                         
+class ArrayFieldsDataSet(DataSet):
+    """
+    Virtual super-class of datasets whose field values are numpy arrays,
+    thus defining valuesHStack and valuesVStack for sub-classes.
+    """
+    def __init__(self,description=None,field_types=None):
+        DataSet.__init__(self,description,field_types)
+    def valuesHStack(self,fieldnames,fieldvalues):
+        """Concatenate field values horizontally, e.g. two vectors
+        become a longer vector, two matrices become a wider matrix, etc."""
+        return numpy.hstack(fieldvalues)
+    def valuesVStack(self,fieldname,values):
+        """Concatenate field values vertically, e.g. two vectors
+        become a two-row matrix, two matrices become a longer matrix, etc."""
+        return numpy.vstack(values)
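+
+    # Illustrative sketch of the intended stacking semantics (not itself
+    # part of the API): valuesVStack('x',[a,b]) with two length-3 vectors
+    # a and b yields a 2x3 matrix, while valuesHStack(['x','y'],[a,c]) with
+    # a length-3 a and a length-2 c yields a single length-5 vector.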
+
+class ArrayDataSet(ArrayFieldsDataSet):
+    """
+    An ArrayDataSet stores the fields as groups of columns in a numpy tensor,
+    whose first axis indexes examples and whose second axis indexes the
+    columns that make up the fields. If the underlying array is N-dimensional
+    (has N axes), then the field values are (N-2)-dimensional objects
+    (i.e. ordinary numbers if N=2).
+    """
+
+    """
+    Construct an ArrayDataSet from the underlying numpy array (data) and
+    a map from fieldnames to field columns. The columns of a field are specified
+    using the standard arguments for indexing/slicing: integer for a column index,
+    slice for an interval of columns (with possible stride), or iterable of column indices.
+    """
+    def __init__(self, data_array, fields_names_columns):
+        self.data=data_array
+        self.fields=fields_names_columns
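+        # For example (a sketch; the names here are illustrative):
+        #   data = numpy.random.randn(100,5)
+        #   ds = ArrayDataSet(data,{'input':slice(0,4),'target':4})
+        # would make 'input' span columns 0-3 and 'target' be column 4.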
+
+        # check consistency and fill in missing slice start/step values
+        for fieldname, fieldcolumns in self.fields.items():
+            if type(fieldcolumns) is int:
+                assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
+            elif type(fieldcolumns) is slice:
+                # normalize the slice so that start and step are always defined
+                start = fieldcolumns.start or 0
+                step = fieldcolumns.step or 1
+                if start!=fieldcolumns.start or step!=fieldcolumns.step:
+                    self.fields[fieldname]=slice(start,fieldcolumns.stop,step)
+            elif hasattr(fieldcolumns,"__iter__"): # something like a list
+                for i in fieldcolumns:
+                    assert i>=0 and i<data_array.shape[1]
+
+    def fieldNames(self):
+        return self.fields.keys()
+
+    def __len__(self):
+        return len(self.data)
+
+    #def __getitem__(self,i):
+    #    """More efficient implementation than the default"""
+
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+        class Iterator(LookupList): # store the result in the lookup-list values
+            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+                if fieldnames is None: fieldnames = dataset.fieldNames()
+                LookupList.__init__(self,fieldnames,[0]*len(fieldnames))
+                self.dataset=dataset
+                self.minibatch_size=minibatch_size
+                assert offset>=0 and offset<len(dataset.data)
+                assert offset+minibatch_size<=len(dataset.data)
+                self.current=offset
+            def __iter__(self):
+                return self
+            def next(self):
+                # advance through the data; n_batches is not used yet
+                if self.current>=len(self.dataset.data):
+                    raise StopIteration
+                sub_data = self.dataset.data[self.current:self.current+self.minibatch_size]
+                self._values = [sub_data[:,self.dataset.fields[f]] for f in self._names]
+                self.current+=self.minibatch_size
+                return self

+        return Iterator(self,fieldnames,minibatch_size,n_batches,offset)
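+
+    # Usage sketch (assuming an ArrayDataSet 'ds' with an 'input' field;
+    # 'process' is an illustrative placeholder):
+    #   for batch in ds.minibatches(['input'],minibatch_size=10):
+    #       process(batch[0])  # a 10-row sub-array of the 'input' columns
+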
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
     """
     Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the