Mercurial > pylearn
diff dataset.py @ 3:378b68d5c4ad
Added first (untested) version of ArrayDataSet
author | bengioy@bengiomac.local |
---|---|
date | Sun, 23 Mar 2008 14:41:22 -0400 |
parents | 3fddb1c8f955 |
children | f7dcfb5f9d5b |
line wrap: on
line diff
--- a/dataset.py Sat Mar 22 22:21:59 2008 -0400 +++ b/dataset.py Sun Mar 23 14:41:22 2008 -0400 @@ -15,22 +15,22 @@ def __init__(self): pass - def __iter__(): + def __iter__(self): return self - def next(): + def next(self): """Return the next example in the dataset.""" raise NotImplementedError - def __getattr__(fieldname): + def __getattr__(self,fieldname): """Return a sub-dataset containing only the given fieldname as field.""" return self(fieldname) - def __call__(*fieldnames): + def __call__(self,*fieldnames): """Return a sub-dataset containing only the given fieldnames as fields.""" raise NotImplementedError - fieldNames(self): + def fieldNames(self): """Return the list of field names that are supported by getattr and getFields.""" raise NotImplementedError @@ -55,4 +55,132 @@ def __getslice__(self,*slice_args): """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" raise NotImplementedError + +# we may want ArrayDataSet defined in another python file + +from numpy import * + +class ArrayDataSet(FiniteDataSet): + """ + A fixed-length and fixed-width dataset in which each element is a numpy.array + or a number, hence the whole dataset corresponds to a numpy.array. Fields + must correspond to a slice of columns. If the dataset has fields, + each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array. + Any dataset can also be converted to a numpy.array (losing the notion of fields) + by the asarray(dataset) call. + """ + + def __self__(self,dataset=None,data=None,fields={}): + """ + Construct an ArrayDataSet, either from a DataSet, or from + a numpy.array plus an optional specification of fields (by + a dictionary of column slices indexed by field names). + """ + self.current_row=-1 # used for view of this dataset as an iterator + if dataset: + assert data==None and fields=={} + # convert dataset to an ArrayDataSet + raise NotImplementedError + if data: + assert dataset==None + self.data=data + self.fields=fields + self.width = data.shape[1] + for fieldname in fields: + fieldslice=fields[fieldname] + assert fieldslice.start>=0 and fieldslice.stop<=width) + + def next(self): + """Return the next example in the dataset. If the dataset has fields, + the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.""" + if fields: + self.current_row+=1 + if self.current_row==len(self.data): + self.current_row=0 + return self[self.current_row] + else: + return self.data[self.current_row] + + def __getattr__(self,fieldname): + """Return a sub-dataset containing only the given fieldname as field.""" + data = self.fields[fieldname] + return ArrayDataSet(data=data) + + def __call__(self,*fieldnames): + """Return a sub-dataset containing only the given fieldnames as fields.""" + min_col=self.data.shape[1] + max_col=0 + for field_slice in self.fields.values(): + min_col=min(min_col,field_slice.start) + max_col=max(max_col,field_slice.stop) + new_fields={} + for field in self.fields: + new_fields[field[0]]=slice(field[1].start-min_col,field[1].stop-min_col,field[1].step) + return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields) + + def fieldNames(self): + """Return the list of field names that are supported by getattr and getFields.""" + return self.fields.keys() + + def __len__(self): + """len(dataset) returns the number of examples in the dataset.""" + return len(self.data) + def __getitem__(self,i): + """ + dataset[i] returns the (i+1)-th example of the dataset. If the dataset has fields + then a one-example dataset is returned (to be able to handle example.field accesses). + """ + if self.fields: + if isinstance(i,slice): + return ArrayDataSet(data=data[slice],fields=self.fields) + return ArrayDataSet(data=self.data[i:i+1],fields=self.fields) + else: + return data[i] + + def __getslice__(self,*slice_args): + """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" + return ArrayDataSet(data=self.data[slice(slice_args)],fields=self.fields) + + def asarray(self): + if self.fields: + columns_used = zeros((self.data.shape[1]),dtype=bool) + for field_slice in self.fields.values(): + for c in xrange(field_slice.start,field_slice.stop,field_slice.step): + columns_used[c]=True + # try to figure out if we can map all the slices into one slice: + mappable_to_one_slice = True + start=0 + while start<len(columns_used) and not columns_used[start]: + start+=1 + stop=len(columns_used) + while stop>0 and not columns_used[stop-1]: + stop-=1 + step=0 + i=start + while i<stop: + j=i+1 + while not columns_used[j] and j<stop: + j+=1 + if step: + if step!=j-i: + mappable_to_one_slice = False + break + else: + step = j-i + if mappable_to_one_slice: + return data[slice(start,stop,step)] + # else make contiguous copy + n_columns = sum(columns_used) + result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype) + c=0 + for field_slice in self.fields.values(): + slice_width=field_slice.stop-field_slice.start + if field_slice.step: + slice_width /= field_slice.step + # copy the field here + result[:,slice(c,slice_width)]=self.data[field_slice] + c+=slice_width + return result + return self.data +