view dataset.py @ 6:d5738b79089a

Removed MinibatchIterator and instead made minibatch_size a field of all DataSets, so that they can all iterate over minibatches, optionally.
author bengioy@bengiomac.local
date Mon, 24 Mar 2008 09:04:06 -0400
parents 8039918516fe
children 6f8f338686db


    
class DataSet(object):
    """
    This is a virtual base class or interface for datasets.
    A dataset is basically an iterator over examples. It does not necessarily
    have a fixed length (this is useful for 'streams' which feed on-line learning).
    Datasets with a fixed and known length are instances of FiniteDataSet, a subclass of DataSet.
    Examples and datasets optionally have named fields. 
    One can obtain a sub-dataset by taking dataset.field or dataset(field1,field2,field3,...).
    Fields are not mutually exclusive, i.e. two fields can overlap in their actual content.
    The content of a field can be of any type, but it will often be a numpy array.
    If the minibatch_size field is greater than 1, the iterator (next() method)
    returns not a single example but an array of length minibatch_size, i.e., an
    indexable object.
    """

    def __init__(self,minibatch_size=1):
        assert minibatch_size>0
        self.minibatch_size=minibatch_size

    def __iter__(self):
        return self

    def next(self):
        """
        Return the next example or the next minibatch in the dataset.
        A minibatch (of length > 1) should be something one can iterate on again in order
        to obtain the individual examples. If the dataset has fields,
        then the example or the minibatch must have the same fields
        (typically this is implemented by returning another (small) dataset, when
        there are fields).
        """
        raise NotImplementedError

    def __getattr__(self,fieldname):
        """Return a sub-dataset containing only the given fieldname as field."""
        return self(fieldname)

    def __call__(self,*fieldnames):
        """Return a sub-dataset containing only the given fieldnames as fields."""
        raise NotImplementedError

    def fieldNames(self):
        """Return the list of field names that are supported by getattr and getFields."""
        raise NotImplementedError
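
# A sketch of the intended iteration protocol (illustration only;
# 'my_dataset' and 'do_something' are hypothetical names):
#
#   for example in my_dataset:         # uses __iter__ and next()
#       do_something(example.input)    # 'input' being one of its fields
#
# With minibatch_size>1 the loop variable is instead an indexable
# minibatch of examples, which can itself be iterated over example-wise.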

class FiniteDataSet(DataSet):
    """
    Virtual interface, a subclass of DataSet for datasets which have a finite, known length.
    Examples are indexed by an integer between 0 and self.length()-1,
    and a subdataset can be obtained by slicing.
    """

    def __init__(self,minibatch_size):
        DataSet.__init__(self,minibatch_size)

    def __len__(self):
        """len(dataset) returns the number of examples in the dataset."""
        raise NotImplementedError
    
    def __getitem__(self,i):
        """dataset[i] returns the (i+1)-th example of the dataset."""
        raise NotImplementedError

    def __getslice__(self,*slice_args):
        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
        raise NotImplementedError
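
# Indexing semantics for any concrete FiniteDataSet, as a sketch
# ('finite_dataset' is a hypothetical instance of such a subclass):
#
#   n = len(finite_dataset)         # number of examples
#   example = finite_dataset[0]     # first example
#   subset = finite_dataset[2:5]    # sub-dataset with examples 2,3,4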

# we may want ArrayDataSet defined in another python file

import numpy

class ArrayDataSet(FiniteDataSet):
    """
    A fixed-length and fixed-width dataset in which each element is a numpy array
    or a number, hence the whole dataset corresponds to a numpy array. Each field
    must correspond to a slice of columns. If the dataset has fields,
    each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array.
    Such a dataset can also be converted to a numpy array (losing the notion of
    fields) by calling dataset.asarray().
    """

    def __init__(self,dataset=None,data=None,fields=None,minibatch_size=1):
        """
        Construct an ArrayDataSet, either from a DataSet, or from
        a numpy array plus an optional specification of fields (by
        a dictionary of column slices indexed by field names).
        """
        FiniteDataSet.__init__(self,minibatch_size)
        # start minibatch_size rows before the first one, so that the first
        # call to next() (which pre-increments) lands on row 0
        self.current_row=-minibatch_size # used for view of this dataset as an iterator
        if fields is None: # avoid a mutable default argument
            fields={}
        if dataset is not None:
            assert data is None and fields=={}
            # convert dataset to an ArrayDataSet
            raise NotImplementedError
        if data is not None:
            assert dataset is None
            self.data=data
            self.fields=fields
            self.width = data.shape[1]
            for fieldname,fieldslice in self.fields.items():
                # make sure fieldslice.start and fieldslice.step are defined
                start=fieldslice.start
                step=fieldslice.step
                if start is None:
                    start=0
                if step is None:
                    step=1
                if fieldslice.start is None or fieldslice.step is None:
                    # store the normalized slice back into the dictionary
                    fieldslice = slice(start,fieldslice.stop,step)
                    self.fields[fieldname] = fieldslice
                # and check that it is coherent with the data array
                assert fieldslice.start>=0 and fieldslice.stop<=self.width
        assert minibatch_size<=len(self.data)
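
    # Example of the fields specification (a sketch; the array below is a
    # hypothetical placeholder): for a matrix whose first three columns are
    # the input and whose last column is the target, one could write
    #
    #   ArrayDataSet(data=some_4_column_array,
    #                fields={'input':slice(0,3),'target':slice(3,4)})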

    def next(self):
        """
        Return the next example(s) in the dataset. If self.minibatch_size>1 return that
        many examples. If the dataset has fields, the example or the minibatch of examples
        is an ArrayDataSet of minibatch_size rows (so that the fields can be accessed),
        but that resulting mini-dataset has a minibatch_size of 1, so that one can iterate
        example-wise on it. On the other hand, if the dataset has no fields (e.g. because
        it is already the field of a bigger dataset), then the returned example or minibatch
        is a numpy array. Following the array semantics of indexing and slicing,
        if the minibatch_size is 1 (and there are no fields), then the result is an array
        with one less dimension (e.g., a vector, if the dataset is a matrix), corresponding
        to a row. Again, if the minibatch_size is >1, one can iterate on the result to
        obtain individual examples (as rows).
        """
        # advance and check for the end of the dataset whether or not there
        # are fields, so that field-less datasets also terminate
        self.current_row+=self.minibatch_size
        if self.current_row>=len(self.data):
            self.current_row=-self.minibatch_size # reset for the next iteration
            raise StopIteration
        if self.fields:
            if self.minibatch_size==1:
                return self[self.current_row]
            else:
                return self[self.current_row:self.current_row+self.minibatch_size]
        else:
            if self.minibatch_size==1:
                return self.data[self.current_row]
            else:
                return self.data[self.current_row:self.current_row+self.minibatch_size]
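
    # Minibatch iteration, as a sketch ('some_array' and 'some_fields' are
    # hypothetical placeholders): with fields and minibatch_size=2, each call
    # to next() returns a 2-row ArrayDataSet whose own minibatch_size is 1,
    # so it can be iterated over to recover individual examples:
    #
    #   for minibatch in ArrayDataSet(data=some_array,fields=some_fields,
    #                                 minibatch_size=2):
    #       for example in minibatch:
    #           pass # process one example (a one-row ArrayDataSet) at a time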

    def __getattr__(self,fieldname):
        """Return a numpy array with the content associated with the given field name."""
        # fields are slices of columns, so select along the second axis
        return self.data[:,self.fields[fieldname]]

    def __call__(self,*fieldnames):
        """Return a sub-dataset containing only the given fieldnames as fields."""
        # find the range of columns spanned by the requested fields
        min_col=self.data.shape[1]
        max_col=0
        for fieldname in fieldnames:
            field_slice=self.fields[fieldname]
            min_col=min(min_col,field_slice.start)
            max_col=max(max_col,field_slice.stop)
        # shift the slices so that they index into the narrower array
        new_fields={}
        for fieldname in fieldnames:
            field_slice=self.fields[fieldname]
            new_fields[fieldname]=slice(field_slice.start-min_col,field_slice.stop-min_col,field_slice.step)
        return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields,minibatch_size=self.minibatch_size)
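
    # Sub-dataset selection, as a sketch (the field names are hypothetical):
    #
    #   inputs_and_targets = dataset('input','target')
    #
    # returns an ArrayDataSet restricted to the columns spanned by the
    # requested fields, with the column slices shifted accordingly.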

    def fieldNames(self):
        """Return the list of field names that are supported by getattr and getFields."""
        return self.fields.keys()

    def __len__(self):
        """len(dataset) returns the number of examples in the dataset."""
        return len(self.data)
    
    def __getitem__(self,i):
        """
        dataset[i] returns the (i+1)-th example of the dataset. If the dataset has fields
        then a one-example dataset is returned (to be able to handle example.field accesses).
        """
        if self.fields:
            if isinstance(i,slice):
                return ArrayDataSet(data=self.data[i],fields=self.fields)
            return ArrayDataSet(data=self.data[i:i+1],fields=self.fields)
        else:
            return self.data[i]

    def __getslice__(self,*slice_args):
        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
        return ArrayDataSet(data=self.data[slice(*slice_args)],fields=self.fields)

    def asarray(self):
        """Return the dataset as a numpy array (losing the notion of fields)."""
        if self.fields:
            columns_used = numpy.zeros((self.data.shape[1]),dtype=bool)
            for field_slice in self.fields.values():
                for c in xrange(field_slice.start,field_slice.stop,field_slice.step):
                    columns_used[c]=True
            # try to figure out if we can map all the slices into one slice:
            mappable_to_one_slice = True
            start=0
            while start<len(columns_used) and not columns_used[start]:
                start+=1
            stop=len(columns_used)
            while stop>0 and not columns_used[stop-1]:
                stop-=1
            step=0
            i=start
            while i<stop:
                # find the next used column after i
                j=i+1
                while j<stop and not columns_used[j]:
                    j+=1
                if j<stop:
                    if step:
                        if step!=j-i:
                            mappable_to_one_slice = False
                            break
                    else:
                        step = j-i
                i=j
            if step==0: # only one used column
                step=1
            if mappable_to_one_slice:
                return self.data[:,slice(start,stop,step)]
            # else make a contiguous copy of the used columns
            n_columns = sum(columns_used)
            result = numpy.zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype)
            c=0
            for field_slice in self.fields.values():
                # number of columns in this field (ceiling of the span over the step)
                slice_width=(field_slice.stop-field_slice.start+field_slice.step-1)//field_slice.step
                # copy the field here
                result[:,c:c+slice_width]=self.data[:,field_slice]
                c+=slice_width
            return result
        return self.data
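
# A small end-to-end sanity check (a sketch, not part of the module proper):
# wrap a 4x3 matrix whose first two columns form the 'input' field and whose
# last column forms the 'target' field, then exercise iteration, field
# access, and asarray().
if __name__ == '__main__':
    a = numpy.array([[1,2,3],
                     [4,5,6],
                     [7,8,9],
                     [10,11,12]])
    ds = ArrayDataSet(data=a,fields={'input':slice(0,2),'target':slice(2,3)})
    for example in ds:
        # each example is a one-row ArrayDataSet, so its fields are accessible
        print example.input, example.target
    # the whole dataset maps back to a single array
    print ds.asarray()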