diff dataset.py @ 3:378b68d5c4ad

Added first (untested) version of ArrayDataSet
author bengioy@bengiomac.local
date Sun, 23 Mar 2008 14:41:22 -0400
parents 3fddb1c8f955
children f7dcfb5f9d5b
--- a/dataset.py	Sat Mar 22 22:21:59 2008 -0400
+++ b/dataset.py	Sun Mar 23 14:41:22 2008 -0400
@@ -15,22 +15,22 @@
     def __init__(self):
         pass
 
-    def __iter__():
+    def __iter__(self):
         return self
 
-    def next():
+    def next(self):
         """Return the next example in the dataset."""
         raise NotImplementedError
 
-    def __getattr__(fieldname):
+    def __getattr__(self,fieldname):
         """Return a sub-dataset containing only the given fieldname as field."""
         return self(fieldname)
 
-    def __call__(*fieldnames):
+    def __call__(self,*fieldnames):
         """Return a sub-dataset containing only the given fieldnames as fields."""
         raise NotImplementedError
 
-    fieldNames(self):
+    def fieldNames(self):
         """Return the list of field names that are supported by getattr and getFields."""
         raise NotImplementedError
 
@@ -55,4 +55,132 @@
     def __getslice__(self,*slice_args):
         """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
         raise NotImplementedError
+
+# we may want ArrayDataSet defined in another python file
+
+from numpy import *
+
+class ArrayDataSet(FiniteDataSet):
+    """
+    A fixed-length and fixed-width dataset in which each element is a numpy.array
+    or a number, hence the whole dataset corresponds to a numpy.array. Fields
+    must correspond to a slice of columns. If the dataset has fields,
+    each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.
+    Any dataset can also be converted to a numpy.array (losing the notion of fields)
+    by calling its asarray() method.
+    """
+
+    def __init__(self,dataset=None,data=None,fields={}):
+        """
+        Construct an ArrayDataSet, either from a DataSet, or from
+        a numpy.array plus an optional specification of fields (by
+        a dictionary of column slices indexed by field names).
+        """
+        self.current_row=-1 # used for view of this dataset as an iterator
+        if dataset is not None:
+            assert data is None and fields=={}
+            # convert dataset to an ArrayDataSet
+            raise NotImplementedError
+        if data is not None:
+            assert dataset is None
+            self.data=data
+            self.fields=fields
+            self.width = data.shape[1]
+            for fieldname in fields:
+                fieldslice=fields[fieldname]
+                assert fieldslice.start>=0 and fieldslice.stop<=self.width
+
+    def next(self):
+        """Return the next example in the dataset. If the dataset has fields,
+        the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array."""
+        self.current_row+=1
+        if self.current_row==len(self.data):
+            self.current_row=0
+        if self.fields:
+            return self[self.current_row]
+        else:
+            return self.data[self.current_row]
+
+    def __getattr__(self,fieldname):
+        """Return a sub-dataset containing only the given fieldname as field."""
+        data = self.data[:,self.fields[fieldname]]
+        return ArrayDataSet(data=data)
+
+    def __call__(self,*fieldnames):
+        """Return a sub-dataset containing only the given fieldnames as fields."""
+        min_col=self.data.shape[1]
+        max_col=0
+        for fieldname in fieldnames:
+            field_slice=self.fields[fieldname]
+            min_col=min(min_col,field_slice.start)
+            max_col=max(max_col,field_slice.stop)
+        new_fields={}
+        for fieldname in fieldnames:
+            field_slice=self.fields[fieldname]
+            new_fields[fieldname]=slice(field_slice.start-min_col,field_slice.stop-min_col,field_slice.step)
+        return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields)
+
+    def fieldNames(self):
+        """Return the list of field names that are supported by getattr and getFields."""
+        return self.fields.keys()
+
+    def __len__(self):
+        """len(dataset) returns the number of examples in the dataset."""
+        return len(self.data)
     
+    def __getitem__(self,i):
+        """
+        dataset[i] returns the (i+1)-th example of the dataset. If the dataset has fields
+        then a one-example dataset is returned (to be able to handle example.field accesses).
+        """
+        if self.fields:
+            if isinstance(i,slice):
+                return ArrayDataSet(data=self.data[i],fields=self.fields)
+            return ArrayDataSet(data=self.data[i:i+1],fields=self.fields)
+        else:
+            return self.data[i]
+
+    def __getslice__(self,*slice_args):
+        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
+        return ArrayDataSet(data=self.data[slice(*slice_args)],fields=self.fields)
+
+    def asarray(self):
+        if self.fields:
+            columns_used = zeros((self.data.shape[1]),dtype=bool)
+            for field_slice in self.fields.values():
+                for c in xrange(field_slice.start,field_slice.stop,field_slice.step or 1):
+                    columns_used[c]=True
+            # try to figure out if we can map all the slices into one slice:
+            mappable_to_one_slice = True
+            start=0
+            while start<len(columns_used) and not columns_used[start]:
+                start+=1
+            stop=len(columns_used)
+            while stop>0 and not columns_used[stop-1]:
+                stop-=1
+            step=0
+            i=start
+            while i<stop:
+                # find the next used column after i (or stop if there is none)
+                j=i+1
+                while j<stop and not columns_used[j]:
+                    j+=1
+                if j==stop:
+                    break
+                if step:
+                    if step!=j-i:
+                        mappable_to_one_slice = False
+                        break
+                else:
+                    step = j-i
+                i=j
+            if mappable_to_one_slice:
+                return self.data[:,start:stop:step or 1]
+            # else make contiguous copy
+            n_columns = sum(columns_used)
+            result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype)
+            c=0
+            for field_slice in self.fields.values():
+                # number of columns covered by this field
+                slice_width=field_slice.stop-field_slice.start
+                if field_slice.step:
+                    slice_width=(slice_width+field_slice.step-1)/field_slice.step
+                # copy the field's columns into the next slice_width columns of result
+                result[:,c:c+slice_width]=self.data[:,field_slice]
+                c+=slice_width
+            return result
+        return self.data
+
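
Usage sketch (illustrative only, not part of the changeset above; the toy data and
the field names 'input' and 'target' are arbitrary choices):

    from numpy import array
    from dataset import ArrayDataSet

    # 3 examples, 3 columns: two input columns followed by one target column
    data = array([[0., 1., 0.],
                  [1., 0., 1.],
                  [1., 1., 1.]])
    dataset = ArrayDataSet(data=data,
                           fields={'input': slice(0, 2), 'target': slice(2, 3)})

    print dataset.fieldNames()        # the declared field names
    print len(dataset)                # number of examples: 3
    example = dataset[0]              # a one-example ArrayDataSet
    print example.input.asarray()     # the two input columns of the first example
    print dataset('target').asarray() # the target column as a numpy.array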