view dataset.py @ 9:de616c423dbd

Improving comments in dataset.py
author bengioy@esprit.iro.umontreal.ca
date Mon, 24 Mar 2008 16:52:47 -0400
parents d1c394486037
children be128b9127c8 88168361a5ab
line wrap: on
line source


    
class DataSet(object):
    """
    This is a virtual base class or interface for datasets.
    A dataset is basically an iterator over examples. It does not necessarily
    have a fixed length (this is useful for 'streams' which feed on-line learning).
    Datasets with fixed and known length are FiniteDataSet, a subclass of DataSet.
    Examples and datasets optionally have named fields.
    One can obtain a sub-dataset by taking dataset.field or dataset(field1,field2,field3,...).
    Fields are not mutually exclusive, i.e. two fields can overlap in their actual content.
    The content of a field can be of any type, but often will be a numpy array.
    The minibatch_size attribute, if different than 1, means that the iterator (next() method)
    returns not a single example but an array of length minibatch_size, i.e., an indexable
    object with minibatch_size examples in it.
    """

    def __init__(self,minibatch_size=1):
        """minibatch_size: number of examples returned per iteration step; must be > 0."""
        assert minibatch_size>0
        self.minibatch_size=minibatch_size

    def __iter__(self):
        """
        Return an iterator, whose next() method returns the next example or the next
        minibatch in the dataset. A minibatch (of length > 1) should be something one
        can iterate on again in order to obtain the individual examples. If the dataset
        has fields, then the example or the minibatch must have the same fields
        (typically this is implemented by returning another smaller dataset, when
        there are fields).
        """
        raise NotImplementedError

    def __getattr__(self,fieldname):
        """
        Return a sub-dataset containing only the given fieldname as field.

        Python only calls this when normal attribute lookup fails. Special
        (double-underscore) names are never field names: they raise AttributeError
        so that hasattr(), copy, pickle and numpy protocol probes behave normally
        instead of being routed to __call__.
        """
        if fieldname.startswith('__') and fieldname.endswith('__'):
            raise AttributeError(fieldname)
        return self(fieldname)

    def __call__(self,*fieldnames):
        """Return a sub-dataset containing only the given fieldnames as fields."""
        raise NotImplementedError

    def fieldNames(self):
        """Return the list of field names usable as dataset.fieldname or in dataset(fieldname,...)."""
        raise NotImplementedError

class FiniteDataSet(DataSet):
    """
    Virtual interface, a subclass of DataSet for datasets which have a finite, known length.
    Examples are indexed by an integer between 0 and len(self)-1,
    and a subdataset can be obtained by slicing.
    """

    def __init__(self,minibatch_size=1):
        # Same default as DataSet.__init__ (one example per iteration step).
        DataSet.__init__(self,minibatch_size)

    def __iter__(self):
        """Return a FiniteDataSetIterator that steps through the examples minibatch_size at a time."""
        return FiniteDataSetIterator(self)

    def __len__(self):
        """len(dataset) returns the number of examples in the dataset."""
        raise NotImplementedError

    def __getitem__(self,i):
        """
        dataset[i] returns the (i+1)-th example of the dataset.
        Under Python 3, dataset[i:j] also arrives here with i being a slice object.
        """
        raise NotImplementedError

    def __getslice__(self,*slice_args):
        """
        dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.
        Only Python 2 calls __getslice__; Python 3 routes slices to __getitem__.
        """
        raise NotImplementedError

class FiniteDataSetIterator(object):
    """
    Iterator over a finite dataset, returning one example (minibatch_size == 1) or
    one slice of minibatch_size consecutive examples per step. After raising
    StopIteration it rewinds to the start, so it can be iterated again.
    """
    def __init__(self,dataset):
        self.dataset=dataset
        # Start one step before index 0 so that the first next() lands on example 0.
        self.current = -self.dataset.minibatch_size

    def __iter__(self):
        """Return self, as required by the iterator protocol."""
        return self

    def next(self):
        """
        Return the next example(s) in the dataset. If self.dataset.minibatch_size>1 return that
        many examples. If the dataset has fields, the example or the minibatch of examples
        is just a minibatch_size-rows ArrayDataSet (so that the fields can be accessed),
        but that resulting mini-dataset has a minibatch_size of 1, so that one can iterate
        example-wise on it. On the other hand, if the dataset has no fields (e.g. because
        it is already the field of a bigger dataset), then the returned example or minibatch
        may be any indexable object, such as a numpy array. Following the array semantics of indexing
        and slicing, if the minibatch_size is 1 (and there are no fields), then the result is an array
        with one less dimension (e.g., a vector, if the dataset is a matrix), corresponding
        to a row. Again, if the minibatch_size is >1, one can iterate on the result to
        obtain individual examples (as rows).
        The last minibatch may hold fewer than minibatch_size examples when the dataset
        length is not a multiple of it and the dataset's slicing clips at the end (as numpy does).
        """
        self.current+=self.dataset.minibatch_size
        if self.current>=len(self.dataset):
            # rewind so the iterator can be reused for another pass
            self.current=-self.dataset.minibatch_size
            raise StopIteration
        if self.dataset.minibatch_size==1:
            return self.dataset[self.current]
        else:
            return self.dataset[self.current:self.current+self.dataset.minibatch_size]

    # Python 3 name of the iterator protocol's step method.
    __next__ = next


# we may want ArrayDataSet defined in another python file

import numpy

class ArrayDataSet(FiniteDataSet):
    """
    An ArrayDataSet behaves like a numpy array but adds the notion of fields
    and minibatch_size from DataSet. It is a fixed-length and fixed-width dataset
    in which each element is a numpy array or a number, hence the whole
    dataset corresponds to a numpy array. Fields
    must correspond to a slice of array columns. If the dataset has fields,
    each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array.
    Any dataset can also be converted to a numpy array (losing the notion of fields
    and of minibatch_size) by the numpy.array(dataset) call.
    """

    def __init__(self,dataset=None,data=None,fields=None,minibatch_size=1):
        """
        There are two ways to construct an ArrayDataSet: (1) from an
        existing dataset (which may result in a copy of the data in a numpy array),
        or (2) from a numpy.array (the data argument), along with an optional description
        of the fields (dictionary of column slices indexed by field names).

        Way (1) is not implemented yet and raises NotImplementedError.
        Each field slice is stored normalized with explicit integers: start defaults
        to 0, stop to the number of columns, step to 1. The caller's fields
        dictionary is copied, never modified.
        Raises ValueError when neither dataset nor data is given.
        """
        FiniteDataSet.__init__(self,minibatch_size)
        # Use identity tests: '!= None' on a numpy array is elementwise.
        if dataset is not None:
            assert data is None and not fields
            # convert dataset to an ArrayDataSet
            raise NotImplementedError
        if data is None:
            raise ValueError("ArrayDataSet needs either a dataset or a data array")
        self.data=data
        self.width = data.shape[1]
        normalized={}
        for fieldname, fieldslice in (fields or {}).items():
            start = fieldslice.start or 0
            step = fieldslice.step or 1
            stop = fieldslice.stop if fieldslice.stop is not None else self.width
            # the field must be coherent with the data array
            assert start>=0 and stop<=self.width
            normalized[fieldname]=slice(start,stop,step)
        self.fields=normalized
        assert minibatch_size<=len(self.data)

    def __getattr__(self,fieldname):
        """
        Return a numpy array with the content associated with the given field name.
        If this is a one-example dataset, then a row, i.e., numpy array (of one less dimension
        than the dataset.data) is returned.

        Names that are not fields raise AttributeError (not KeyError), so hasattr(),
        copy/pickle and numpy's __array_struct__/__array_interface__ probes work.
        """
        # Read the instance dict directly: a half-built object (no 'fields' yet)
        # must not recurse back into __getattr__ through self.fields / self.data.
        fields=self.__dict__.get('fields')
        if fields is None or fieldname not in fields:
            raise AttributeError("'%s' object has no attribute '%s'" % (type(self).__name__,fieldname))
        data=self.__dict__['data']
        if len(data)==1:
            return data[0,fields[fieldname]]
        return data[:,fields[fieldname]]

    def __call__(self,*fieldnames):
        """
        Return a sub-dataset containing only the given fieldnames as fields.

        Its data is the narrowest column range of self.data covering the requested
        fields (obtained by basic slicing, so a view for numpy data), and each field
        slice is shifted to be relative to that range. Raises KeyError for an
        unknown field name.
        """
        min_col=self.data.shape[1]
        max_col=0
        for name in fieldnames:
            field_slice=self.fields[name]
            min_col=min(min_col,field_slice.start)
            max_col=max(max_col,field_slice.stop)
        new_fields={}
        for name in fieldnames:
            field_slice=self.fields[name]
            new_fields[name]=slice(field_slice.start-min_col,field_slice.stop-min_col,field_slice.step)
        return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields,minibatch_size=self.minibatch_size)

    def fieldNames(self):
        """Return the list of field names usable as dataset.fieldname or in dataset(fieldname,...)."""
        return list(self.fields.keys())

    def __len__(self):
        """len(dataset) returns the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self,i):
        """
        dataset[i] returns the (i+1)-th example of the dataset. If the dataset has fields
        then a one-example dataset is returned (to be able to handle example.field accesses);
        a slice object then yields a sub-ArrayDataSet with the same fields.
        Without fields, self.data[i] is returned unchanged.
        Raises IndexError when an integer index is out of range.
        """
        if not self.fields:
            return self.data[i]
        if isinstance(i,slice):
            return ArrayDataSet(data=self.data[i],fields=self.fields)
        n=len(self.data)
        if i<0:
            # data[i:i+1] with a negative i would be an empty (or wrong) range
            i+=n
        if not 0<=i<n:
            raise IndexError("ArrayDataSet index out of range")
        return ArrayDataSet(data=self.data[i:i+1],fields=self.fields)

    def __getslice__(self,*slice_args):
        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1 (Python 2 only)."""
        return ArrayDataSet(data=self.data[slice(*slice_args)],fields=self.fields)

    def __array__(self,dtype=None):
        """
        Convert the dataset to a plain numpy array, dropping fields and minibatch_size.

        Without fields, self.data is returned as is. With fields, only the columns
        covered by at least one field are kept, in increasing column order (a column
        shared by overlapping fields appears once). When those columns are evenly
        spaced the result is a strided slice of self.data; otherwise the columns are
        gathered with integer-array indexing, which makes a copy.
        dtype: optional dtype, as passed by numpy.array(dataset, dtype=...).
        """
        if not self.fields:
            result=self.data
        else:
            columns_used=numpy.zeros(self.width,dtype=bool)
            for field_slice in self.fields.values():
                columns_used[field_slice]=True
            used=numpy.flatnonzero(columns_used)
            gaps=numpy.diff(used)
            if len(used)==0:
                # only empty fields: no column survives
                result=self.data[:,0:0]
            elif len(gaps)==0 or (gaps==gaps[0]).all():
                # used columns form an arithmetic progression: a single slice suffices
                step=int(gaps[0]) if len(gaps) else 1
                result=self.data[:,int(used[0]):int(used[-1])+1:step]
            else:
                # irregular layout: gather the used columns into a contiguous copy
                result=self.data[:,used]
        if dtype is not None:
            result=numpy.asarray(result,dtype=dtype)
        return result