Mercurial > pylearn
changeset 4:f7dcfb5f9d5b
Added test for dataset.
author | bengioy@bengiomac.local |
---|---|
date | Sun, 23 Mar 2008 22:14:10 -0400 |
parents | 378b68d5c4ad |
children | 8039918516fe |
files | _test_dataset.py dataset.py |
diffstat | 2 files changed, 54 insertions(+), 15 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/_test_dataset.py Sun Mar 23 22:14:10 2008 -0400 @@ -0,0 +1,25 @@ +from dataset import * +from math import * +import unittest + +def _sum_all(a): + s=a + while isinstance(s,numpy.ndarray): + s=sum(s) + return s + +class T_arraydataset(unittest.TestCase): + def setUp(self): + numpy.random.seed(123456) + + def test0(self): + a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)}) + s=0 + for example in a: + s+=_sum_all(example.x) + print s + self.failUnless(abs(s-11.4674133)<1e-6) + +if __name__ == '__main__': + unittest.main() +
--- a/dataset.py Sun Mar 23 14:41:22 2008 -0400 +++ b/dataset.py Sun Mar 23 22:14:10 2008 -0400 @@ -58,7 +58,7 @@ # we may want ArrayDataSet defined in another python file -from numpy import * +import numpy class ArrayDataSet(FiniteDataSet): """ @@ -70,41 +70,57 @@ by the asarray(dataset) call. """ - def __self__(self,dataset=None,data=None,fields={}): + def __init__(self,dataset=None,data=None,fields={}): """ Construct an ArrayDataSet, either from a DataSet, or from a numpy.array plus an optional specification of fields (by a dictionary of column slices indexed by field names). """ self.current_row=-1 # used for view of this dataset as an iterator - if dataset: + if dataset!=None: assert data==None and fields=={} # convert dataset to an ArrayDataSet raise NotImplementedError - if data: + if data!=None: assert dataset==None self.data=data self.fields=fields self.width = data.shape[1] for fieldname in fields: fieldslice=fields[fieldname] - assert fieldslice.start>=0 and fieldslice.stop<=width) + # make sure fieldslice.start and fieldslice.step are defined + start=fieldslice.start + step=fieldslice.step + if not start: + start=0 + if not step: + step=1 + if not fieldslice.start or not fieldslice.step: + fieldslice = slice(start,fieldslice.stop,step) + # and coherent with the data array + assert fieldslice.start>=0 and fieldslice.stop<=self.width def next(self): - """Return the next example in the dataset. If the dataset has fields, - the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.""" - if fields: + """ + Return the next example in the dataset. If the dataset has fields, + the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array. + """ + if self.fields: self.current_row+=1 if self.current_row==len(self.data): - self.current_row=0 + self.current_row=-1 + raise StopIteration return self[self.current_row] else: return self.data[self.current_row] def __getattr__(self,fieldname): """Return a sub-dataset containing only the given fieldname as field.""" - data = self.fields[fieldname] - return ArrayDataSet(data=data) + data=self.data[self.fields[fieldname]] + if len(data)==1: + return data + else: + return ArrayDataSet(data=data) def __call__(self,*fieldnames): """Return a sub-dataset containing only the given fieldnames as fields.""" @@ -144,7 +160,7 @@ def asarray(self): if self.fields: - columns_used = zeros((self.data.shape[1]),dtype=bool) + columns_used = numpy.zeros((self.data.shape[1]),dtype=bool) for field_slice in self.fields.values(): for c in xrange(field_slice.start,field_slice.stop,field_slice.step): columns_used[c]=True @@ -175,9 +191,7 @@ result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype) c=0 for field_slice in self.fields.values(): - slice_width=field_slice.stop-field_slice.start - if field_slice.step: - slice_width /= field_slice.step + slice_width=field_slice.stop-field_slice.start/field_slice.step # copy the field here result[:,slice(c,slice_width)]=self.data[field_slice] c+=slice_width