Mercurial > pylearn
view dataset.py @ 5:8039918516fe
Added MinibatchIterator
author | bengioy@bengiomac.local |
---|---|
date | Sun, 23 Mar 2008 22:44:43 -0400 |
parents | f7dcfb5f9d5b |
children | d5738b79089a |
line wrap: on
line source
class DataSet(object): """ This is a virtual base class or interface for datasets. A dataset is basically an iterator over examples. It does not necessarily have a fixed length (this is useful for 'streams' which feed on-line learning). Datasets with fixed and known length are FiniteDataSet, a subclass of DataSet. Examples and datasets have named fields. One can obtain a sub-dataset by taking dataset.field or dataset(field1,field2,field3,...). Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. The content of a field can be of any type, but often will be a numpy tensor. """ def __init__(self): pass def __iter__(self): return self def next(self): """Return the next example in the dataset.""" raise NotImplementedError def __getattr__(self,fieldname): """Return a sub-dataset containing only the given fieldname as field.""" return self(fieldname) def __call__(self,*fieldnames): """Return a sub-dataset containing only the given fieldnames as fields.""" raise NotImplementedError def fieldNames(self): """Return the list of field names that are supported by getattr and getFields.""" raise NotImplementedError class FiniteDataSet(DataSet): """ Virtual interface, a subclass of DataSet for datasets which have a finite, known length. Examples are indexed by an integer between 0 and self.length()-1, and a subdataset can be obtained by slicing. """ def __init__(self): pass def __len__(self): """len(dataset) returns the number of examples in the dataset.""" raise NotImplementedError def __getitem__(self,i): """dataset[i] returns the (i+1)-th example of the dataset.""" raise NotImplementedError def __getslice__(self,*slice_args): """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" raise NotImplementedError def minibatches(self,minibatch_size): """Return an iterator for the dataset that goes through minibatches of the given size.""" return MinibatchIterator(self,minibatch_size) class MinibatchIterator(object): """ Iterator class for FiniteDataSet that can iterate by minibatches (sub-dataset of consecutive examples). """ def __init__(self,dataset,minibatch_size): assert minibatch_size>0 and minibatch_size<len(dataset) self.dataset=dataset self.minibatch_size=minibatch_size self.current=-minibatch_size def __iter__(self): return self def next(self): self.current+=self.minibatch_size if self.current>=len(self.dataset): self.current=-self.minibatchsize raise StopIteration return self.dataset[self.current:self.current+self.minibatchsize] # we may want ArrayDataSet defined in another python file import numpy class ArrayDataSet(FiniteDataSet): """ A fixed-length and fixed-width dataset in which each element is a numpy.array or a number, hence the whole dataset corresponds to a numpy.array. Fields must correspond to a slice of columns. If the dataset has fields, each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array. Any dataset can also be converted to a numpy.array (losing the notion of fields) by the asarray(dataset) call. """ def __init__(self,dataset=None,data=None,fields={}): """ Construct an ArrayDataSet, either from a DataSet, or from a numpy.array plus an optional specification of fields (by a dictionary of column slices indexed by field names). """ self.current_row=-1 # used for view of this dataset as an iterator if dataset!=None: assert data==None and fields=={} # convert dataset to an ArrayDataSet raise NotImplementedError if data!=None: assert dataset==None self.data=data self.fields=fields self.width = data.shape[1] for fieldname in fields: fieldslice=fields[fieldname] # make sure fieldslice.start and fieldslice.step are defined start=fieldslice.start step=fieldslice.step if not start: start=0 if not step: step=1 if not fieldslice.start or not fieldslice.step: fieldslice = slice(start,fieldslice.stop,step) # and coherent with the data array assert fieldslice.start>=0 and fieldslice.stop<=self.width def next(self): """ Return the next example in the dataset. If the dataset has fields, the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array. """ if self.fields: self.current_row+=1 if self.current_row==len(self.data): self.current_row=-1 raise StopIteration return self[self.current_row] else: return self.data[self.current_row] def __getattr__(self,fieldname): """Return a sub-dataset containing only the given fieldname as field.""" data=self.data[self.fields[fieldname]] if len(data)==1: return data else: return ArrayDataSet(data=data) def __call__(self,*fieldnames): """Return a sub-dataset containing only the given fieldnames as fields.""" min_col=self.data.shape[1] max_col=0 for field_slice in self.fields.values(): min_col=min(min_col,field_slice.start) max_col=max(max_col,field_slice.stop) new_fields={} for field in self.fields: new_fields[field[0]]=slice(field[1].start-min_col,field[1].stop-min_col,field[1].step) return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields) def fieldNames(self): """Return the list of field names that are supported by getattr and getFields.""" return self.fields.keys() def __len__(self): """len(dataset) returns the number of examples in the dataset.""" return len(self.data) def __getitem__(self,i): """ dataset[i] returns the (i+1)-th example of the dataset. If the dataset has fields then a one-example dataset is returned (to be able to handle example.field accesses). """ if self.fields: if isinstance(i,slice): return ArrayDataSet(data=data[slice],fields=self.fields) return ArrayDataSet(data=self.data[i:i+1],fields=self.fields) else: return data[i] def __getslice__(self,*slice_args): """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" return ArrayDataSet(data=self.data[slice(slice_args)],fields=self.fields) def asarray(self): if self.fields: columns_used = numpy.zeros((self.data.shape[1]),dtype=bool) for field_slice in self.fields.values(): for c in xrange(field_slice.start,field_slice.stop,field_slice.step): columns_used[c]=True # try to figure out if we can map all the slices into one slice: mappable_to_one_slice = True start=0 while start<len(columns_used) and not columns_used[start]: start+=1 stop=len(columns_used) while stop>0 and not columns_used[stop-1]: stop-=1 step=0 i=start while i<stop: j=i+1 while not columns_used[j] and j<stop: j+=1 if step: if step!=j-i: mappable_to_one_slice = False break else: step = j-i if mappable_to_one_slice: return data[slice(start,stop,step)] # else make contiguous copy n_columns = sum(columns_used) result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype) c=0 for field_slice in self.fields.values(): slice_width=field_slice.stop-field_slice.start/field_slice.step # copy the field here result[:,slice(c,slice_width)]=self.data[field_slice] c+=slice_width return result return self.data