view dataset.py @ 4:f7dcfb5f9d5b

Added test for dataset.
author bengioy@bengiomac.local
date Sun, 23 Mar 2008 22:14:10 -0400
parents 378b68d5c4ad
children 8039918516fe
line wrap: on
line source


    
class DataSet(object):
    """
    Abstract interface for datasets.

    A dataset is an iterator over examples. Its length need not be fixed or
    even known, which suits 'streams' feeding on-line learning; datasets with
    a fixed, known length are FiniteDataSet, a subclass of DataSet.
    Both examples and datasets have named fields, and a sub-dataset
    restricted to some fields is obtained with dataset.field or
    dataset(field1, field2, field3, ...).
    Fields may overlap, i.e. two fields can share part of their content.
    A field's content can be of any type, but will often be a numpy tensor.
    The abstract methods (next, __call__, fieldNames) raise
    NotImplementedError until a subclass overrides them.
    """

    def __init__(self):
        """Nothing to initialise in the abstract base."""

    def __iter__(self):
        """A dataset is its own iterator."""
        return self

    def next(self):
        """Return the next example in the dataset (abstract)."""
        raise NotImplementedError

    def __getattr__(self, fieldname):
        """
        Fallback for attribute names not found normally: dataset.fieldname
        is shorthand for dataset(fieldname), a sub-dataset with that field only.
        """
        sub_dataset = self(fieldname)
        return sub_dataset

    def __call__(self, *fieldnames):
        """Return a sub-dataset restricted to the given fieldnames (abstract)."""
        raise NotImplementedError

    def fieldNames(self):
        """Return the list of field names usable as attributes or in a call (abstract)."""
        raise NotImplementedError

class FiniteDataSet(DataSet):
    """
    Abstract interface for datasets whose length is finite and known.

    Examples are addressed by an integer index between 0 and len(dataset)-1,
    and slicing yields a sub-dataset.
    """

    def __init__(self):
        """Nothing to initialise in the abstract base."""

    def __len__(self):
        """Number of examples, so that len(dataset) works (abstract)."""
        raise NotImplementedError

    def __getitem__(self, i):
        """dataset[i] is the example at position i, i.e. the (i+1)-th one (abstract)."""
        raise NotImplementedError

    def __getslice__(self, *slice_args):
        """dataset[i:j] is the sub-dataset of examples i, i+1, ..., j-1 (abstract; Python 2 slicing hook)."""
        raise NotImplementedError

# we may want ArrayDataSet defined in another python file

import numpy

class ArrayDataSet(FiniteDataSet):
    """
    A fixed-length and fixed-width dataset in which each element is a numpy.array
    or a number, hence the whole dataset corresponds to a numpy.array. Fields
    must correspond to a slice of columns. If the dataset has fields,
    each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.
    Any dataset can also be converted to a numpy.array (losing the notion of
    fields) by calling its asarray() method.

    Attributes:
      data        -- the underlying numpy array; rows are examples and axis 1
                     holds the columns that fields slice (None if the dataset
                     was built without data).
      fields      -- dict mapping field name -> column slice whose start,
                     stop and step are all explicit integers.
      width       -- data.shape[1] (only set when data was given).
      current_row -- position of the iterator view, -1 before the first example.
    """

    def __init__(self, dataset=None, data=None, fields=None):
        """
        Construct an ArrayDataSet, either from a DataSet, or from
        a numpy.array plus an optional specification of fields (by
        a dictionary of column slices indexed by field names).

        Missing slice components are filled in (start -> 0, step -> 1,
        stop -> number of columns) and the normalized slices are stored in a
        new dict; the caller's `fields` dict is never shared or mutated.

        Raises NotImplementedError when `dataset` is given (conversion from a
        generic DataSet is not implemented yet).
        """
        self.current_row = -1  # used for view of this dataset as an iterator
        # Set these before anything else so that reading them never falls
        # through to __getattr__ (which itself reads self.fields).
        self.data = None
        self.fields = {}
        if fields is None:
            fields = {}
        # 'is None' rather than '== None': comparing a numpy array to None is
        # elementwise and its truth value is ambiguous.
        if dataset is not None:
            assert data is None and not fields
            # convert dataset to an ArrayDataSet
            raise NotImplementedError
        if data is not None:
            self.data = data
            self.width = data.shape[1]
            for fieldname, fieldslice in fields.items():
                # make sure start, stop and step are all defined
                start = fieldslice.start
                stop = fieldslice.stop
                step = fieldslice.step
                if not start:
                    start = 0
                if stop is None:
                    stop = self.width
                if not step:
                    step = 1
                # and coherent with the data array
                assert start >= 0 and stop <= self.width
                self.fields[fieldname] = slice(start, stop, step)

    def next(self):
        """
        Return the next example in the dataset. If the dataset has fields,
        the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.
        Raises StopIteration after the last example and rewinds, so the
        dataset can be iterated again.
        """
        self.current_row += 1
        if self.current_row >= len(self.data):
            self.current_row = -1
            raise StopIteration
        return self[self.current_row]

    # Python 3 name of the iterator-protocol method.
    __next__ = next

    def __getattr__(self, fieldname):
        """
        Return the columns of the given field (dataset.fieldname).

        If those columns span a single row (e.g. an example obtained by
        iteration or indexing) the numpy array itself is returned, otherwise
        a field-less ArrayDataSet wrapping them.
        Raises AttributeError for names that are not fields, as the attribute
        protocol requires, so hasattr() and getattr(obj, name, default) work.
        """
        # 'data'/'fields' only reach here on a partially built object; looking
        # them up through self.fields would recurse forever.
        if fieldname in ('data', 'fields'):
            raise AttributeError(fieldname)
        try:
            field_slice = self.fields[fieldname]
        except KeyError:
            raise AttributeError("ArrayDataSet has no field %r" % (fieldname,))
        # fields are column slices: keep every row, select the field's columns
        data = self.data[:, field_slice]
        if len(data) == 1:
            return data
        return ArrayDataSet(data=data)

    def __call__(self, *fieldnames):
        """
        Return a sub-dataset containing only the given fieldnames as fields.

        The sub-dataset wraps the smallest column range covering the requested
        fields, and their slices are re-expressed relative to that range.
        Raises KeyError for an unknown field name.
        """
        selected = dict((name, self.fields[name]) for name in fieldnames)
        min_col = self.data.shape[1]
        max_col = 0
        for field_slice in selected.values():
            min_col = min(min_col, field_slice.start)
            max_col = max(max_col, field_slice.stop)
        new_fields = {}
        for name, field_slice in selected.items():
            new_fields[name] = slice(field_slice.start - min_col,
                                     field_slice.stop - min_col,
                                     field_slice.step)
        return ArrayDataSet(data=self.data[:, min_col:max_col], fields=new_fields)

    def fieldNames(self):
        """Return the list of field names that are supported by getattr and __call__."""
        return list(self.fields.keys())

    def __len__(self):
        """len(dataset) returns the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, i):
        """
        dataset[i] returns the (i+1)-th example of the dataset. If the dataset has fields
        then a one-example dataset is returned (to be able to handle example.field accesses),
        otherwise the raw row self.data[i]. Negative indices count from the end;
        an out-of-range index raises IndexError.
        dataset[i:j] (a slice object) returns the sub-dataset of those examples,
        keeping the fields.
        """
        if isinstance(i, slice):
            return ArrayDataSet(data=self.data[i], fields=self.fields)
        if not self.fields:
            return self.data[i]
        n_examples = len(self.data)
        # data[i:i+1] would silently be empty for negative or too-large i
        if i < 0:
            i += n_examples
        if not 0 <= i < n_examples:
            raise IndexError("ArrayDataSet index out of range")
        return ArrayDataSet(data=self.data[i:i + 1], fields=self.fields)

    def __getslice__(self, *slice_args):
        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1 (Python 2 slicing hook)."""
        return ArrayDataSet(data=self.data[slice(*slice_args)], fields=self.fields)

    def asarray(self):
        """
        Return the dataset as a numpy array, dropping the notion of fields.

        Without fields this is self.data itself. With fields, only the columns
        used by at least one field are kept, in increasing column order (a
        column shared by several fields appears once). When those columns are
        evenly spaced the result is a slicing view of self.data; otherwise it
        is a copy made by integer-array indexing.
        """
        if not self.fields:
            return self.data
        columns_used = numpy.zeros(self.data.shape[1], dtype=bool)
        for field_slice in self.fields.values():
            columns_used[field_slice] = True
        used = numpy.flatnonzero(columns_used)
        if len(used) == 0:
            # fields exist but all are empty slices
            return self.data[:, 0:0]
        gaps = numpy.diff(used)
        if len(gaps) == 0 or (gaps == gaps[0]).all():
            # the used columns form one arithmetic progression: a single slice
            step = int(gaps[0]) if len(gaps) else 1
            return self.data[:, used[0]:used[-1] + 1:step]
        # else make a contiguous copy of the used columns
        return self.data[:, used]