Mercurial > pylearn
view dataset.py @ 8:d1c394486037
Replaced the asarray() method with an __array__ method, which gets called automatically when
the dataset is converted into an array or numpy array, e.g. by numpy.asarray(dataset).
author | bengioy@bengiomac.local |
---|---|
date | Mon, 24 Mar 2008 15:56:53 -0400 |
parents | 6f8f338686db |
children | de616c423dbd |
line wrap: on
line source
import numpy


class DataSet(object):
    """
    This is a virtual base class or interface for datasets.
    A dataset is basically an iterator over examples. It does not necessarily
    have a fixed length (this is useful for 'streams' which feed on-line learning).
    Datasets with fixed and known length are FiniteDataSet, a subclass of DataSet.
    Examples and datasets optionally have named fields.
    One can obtain a sub-dataset by taking dataset.field or dataset(field1,field2,field3,...).
    Fields are not mutually exclusive, i.e. two fields can overlap in their actual content.
    The content of a field can be of any type, but often will be a numpy array.
    The minibatch_size field, if different than 1, means that the iterator (next() method)
    returns not a single example but an array of length minibatch_size, i.e., an indexable
    object.
    """

    def __init__(self, minibatch_size=1):
        """minibatch_size: number of examples returned by each step of the iterator (> 0)."""
        assert minibatch_size > 0
        self.minibatch_size = minibatch_size

    def __iter__(self):
        """
        Return an iterator, whose next() method returns the next example or the next
        minibatch in the dataset. A minibatch (of length > 1) should be something one can
        iterate on again in order to obtain the individual examples. If the dataset has
        fields, then the example or the minibatch must have the same fields (typically
        this is implemented by returning another (small) dataset, when there are fields).
        """
        raise NotImplementedError

    def __getattr__(self, fieldname):
        """
        Return a sub-dataset containing only the given fieldname as field.

        Python only calls this when normal attribute lookup fails. Special
        ("__dunder__") names are never fields: numpy, copy, pickle and hasattr()
        probe for them and expect AttributeError when they are absent.
        """
        if fieldname.startswith('__') and fieldname.endswith('__'):
            raise AttributeError(fieldname)
        return self(fieldname)

    def __call__(self, *fieldnames):
        """Return a sub-dataset containing only the given fieldnames as fields."""
        raise NotImplementedError

    def fieldNames(self):
        """Return the list of field names that are supported by getattr and __call__."""
        raise NotImplementedError


class FiniteDataSet(DataSet):
    """
    Virtual interface, a subclass of DataSet for datasets which have a finite, known length.
    Examples are indexed by an integer between 0 and len(self)-1,
    and a subdataset can be obtained by slicing.
    """

    def __init__(self, minibatch_size=1):
        DataSet.__init__(self, minibatch_size)

    def __iter__(self):
        """Iterate over the examples (or minibatches) in index order."""
        return FiniteDataSetIterator(self)

    def __len__(self):
        """len(dataset) returns the number of examples in the dataset."""
        raise NotImplementedError

    def __getitem__(self, i):
        """dataset[i] returns the (i+1)-th example of the dataset."""
        raise NotImplementedError

    def __getslice__(self, *slice_args):
        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.

        Only Python 2 calls __getslice__; Python 3 passes a slice to __getitem__.
        """
        raise NotImplementedError


class FiniteDataSetIterator(object):
    """Iterator over a FiniteDataSet, yielding one example or one minibatch per step."""

    def __init__(self, dataset):
        self.dataset = dataset
        # next() advances by minibatch_size before reading, so start one step before 0.
        self.current = -self.dataset.minibatch_size

    def __iter__(self):
        return self

    def next(self):
        """
        Return the next example(s) in the dataset. If self.dataset.minibatch_size>1 return
        that many examples (the last minibatch may be shorter). If the dataset has fields,
        the example or the minibatch of examples is just a minibatch_size-rows ArrayDataSet
        (so that the fields can be accessed), but that resulting mini-dataset has a
        minibatch_size of 1, so that one can iterate example-wise on it. On the other hand,
        if the dataset has no fields (e.g. because it is already the field of a bigger
        dataset), then the returned example or minibatch may be any indexable object, such
        as a numpy array. Following the array semantics of indexing and slicing, if the
        minibatch_size is 1 (and there are no fields), then the result is an array with one
        less dimension (e.g., a vector, if the dataset is a matrix), corresponding to a row.
        Again, if the minibatch_size is >1, one can iterate on the result to obtain
        individual examples (as rows).
        """
        minibatch_size = self.dataset.minibatch_size
        self.current += minibatch_size
        if self.current >= len(self.dataset):
            # rewind so that the same iterator can be used for another pass
            self.current = -minibatch_size
            raise StopIteration
        if minibatch_size == 1:
            return self.dataset[self.current]
        return self.dataset[self.current:self.current + minibatch_size]

    # Python 3 name of the iterator protocol method.
    __next__ = next


# we may want ArrayDataSet defined in another python file

class ArrayDataSet(FiniteDataSet):
    """
    A fixed-length and fixed-width dataset in which each element is a numpy array
    or a number, hence the whole dataset corresponds to a numpy array. Fields
    must correspond to a slice of columns. If the dataset has fields,
    each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array.
    Any dataset can also be converted to a numpy array (losing the notion of fields)
    by the numpy.array(dataset) call.
    """

    def __init__(self, dataset=None, data=None, fields=None, minibatch_size=1):
        """
        Construct an ArrayDataSet, either from a DataSet, or from a numpy array plus an
        optional specification of fields (by a dictionary of column slices indexed by
        field names).

        Field slices are stored normalized (start defaults to 0, stop to the data width,
        step to 1) in a new dict; the caller's dict is never modified.
        """
        FiniteDataSet.__init__(self, minibatch_size)
        if fields is None:
            fields = {}
        # 'is None' rather than '== None': comparing a numpy array to None is elementwise.
        if dataset is not None:
            assert data is None and not fields
            # convert dataset to an ArrayDataSet
            raise NotImplementedError
        if data is not None:
            assert dataset is None
            self.data = data
            self.width = data.shape[1]
            normalized = {}
            for fieldname, fieldslice in fields.items():
                start = fieldslice.start or 0
                stop = self.width if fieldslice.stop is None else fieldslice.stop
                step = fieldslice.step or 1
                # and coherent with the data array
                assert start >= 0 and stop <= self.width
                normalized[fieldname] = slice(start, stop, step)
            self.fields = normalized
            # An empty dataset (e.g. an out-of-range slice) is allowed; otherwise at
            # least one full minibatch must fit.
            assert len(self.data) == 0 or minibatch_size <= len(self.data)

    def __getattr__(self, fieldname):
        """
        Return a numpy array with the content associated with the given field name.
        If this is a one-example dataset, then a row, i.e., numpy array (of one less
        dimension than the dataset.data) is returned.

        Names that are not fields raise AttributeError, as Python's attribute protocol
        requires (hasattr(), copy, pickle and numpy rely on it).
        """
        # Read through __dict__: going through self.fields would re-enter __getattr__
        # (infinitely) when 'fields' has not been set, e.g. on a bare instance.
        fields = self.__dict__.get('fields')
        if not fields or fieldname not in fields:
            raise AttributeError(fieldname)
        if len(self.data) == 1:
            return self.data[0, fields[fieldname]]
        return self.data[:, fields[fieldname]]

    def __call__(self, *fieldnames):
        """
        Return a sub-dataset containing only the given fieldnames as fields.

        The sub-dataset's data is the narrowest column range of self.data covering the
        requested fields, and each field slice is re-based to start from that range.
        """
        selected = [(name, self.fields[name]) for name in fieldnames]
        min_col = self.data.shape[1]
        max_col = 0
        for _, field_slice in selected:
            min_col = min(min_col, field_slice.start)
            max_col = max(max_col, field_slice.stop)
        new_fields = {}
        for name, field_slice in selected:
            new_fields[name] = slice(field_slice.start - min_col,
                                     field_slice.stop - min_col,
                                     field_slice.step)
        return ArrayDataSet(data=self.data[:, min_col:max_col], fields=new_fields,
                            minibatch_size=self.minibatch_size)

    def fieldNames(self):
        """Return the list of field names that are supported by getattr and __call__."""
        return list(self.fields)

    def __len__(self):
        """len(dataset) returns the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, i):
        """
        dataset[i] returns the (i+1)-th example of the dataset.
        If the dataset has fields then a one-example dataset is returned (to be able to
        handle example.field accesses); otherwise the row self.data[i] is returned.
        Negative indices count from the end; out-of-range indices raise IndexError.

        A slice returns the sub-ArrayDataSet of the selected examples, like
        __getslice__ (on Python 3 every dataset[i:j] arrives here).
        """
        if isinstance(i, slice):
            return ArrayDataSet(data=self.data[i], fields=self.fields)
        if not self.fields:
            return self.data[i]
        n = len(self.data)
        if i < 0:
            # self.data[i:i+1] would be empty for negative i, so normalize first
            i += n
        if not 0 <= i < n:
            raise IndexError('ArrayDataSet index out of range')
        return ArrayDataSet(data=self.data[i:i + 1], fields=self.fields)

    def __getslice__(self, *slice_args):
        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
        return ArrayDataSet(data=self.data[slice(*slice_args)], fields=self.fields)

    def __array__(self, dtype=None, copy=None):
        """
        Return the content of the dataset as a numpy array, losing the notion of fields.
        numpy calls this automatically, e.g. in numpy.asarray(dataset).

        Without fields, self.data itself is returned. With fields, the result keeps
        exactly the columns covered by at least one field, in increasing column order
        (a column shared by overlapping fields appears once). Evenly spaced columns are
        selected with a basic slice (a view of self.data); otherwise they are gathered
        into a new array.

        dtype and copy follow numpy's __array__ protocol: dtype, if given, is the
        requested element type; copy=True always returns a fresh array; copy=False
        raises ValueError when a copy cannot be avoided; copy=None copies only if needed.
        """
        result = self.data
        copied = False
        if self.fields:
            columns_used = numpy.zeros(self.data.shape[1], dtype=bool)
            for field_slice in self.fields.values():
                for c in range(field_slice.start, field_slice.stop, field_slice.step):
                    columns_used[c] = True
            columns = numpy.flatnonzero(columns_used)
            if columns.size == 0:
                result = self.data[:, 0:0]
            else:
                gaps = numpy.diff(columns)
                step = int(gaps[0]) if gaps.size else 1
                if (gaps == step).all():
                    # evenly spaced columns map onto one slice: return a view
                    result = self.data[:, int(columns[0]):int(columns[-1]) + 1:step]
                else:
                    # irregular columns need fancy indexing, which always copies
                    result = self.data[:, columns]
                    copied = True
        if dtype is not None and numpy.dtype(dtype) != result.dtype:
            result = result.astype(dtype)
            copied = True
        if copy is False and copied:
            raise ValueError("Unable to avoid copy while creating an array as requested.")
        if copy and not copied:
            result = numpy.array(result, copy=True)
        return result