# HG changeset patch # User bengioy@bengiomac.local # Date 1206388613 14400 # Node ID d1c39448603788bba404fa6ab178490e2d38846f # Parent 6f8f338686db42cd171c41a0cbaa3a880e2bec8d Replaced asarray() method by __array__ method which gets called automatically when trying to cast into an array or numpy array or upon numpy.asarray(dataset). diff -r 6f8f338686db -r d1c394486037 _test_dataset.py --- a/_test_dataset.py Mon Mar 24 13:20:15 2008 -0400 +++ b/_test_dataset.py Mon Mar 24 15:56:53 2008 -0400 @@ -16,23 +16,25 @@ a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)},minibatch_size=1) s=0 for example in a: - print len(example), example.x s+=_sum_all(example.x) - print s + #print s self.failUnless(abs(s-7.25967597)<1e-6) def test1(self): a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)},minibatch_size=1) a.minibatch_size=2 - print a.asarray() + s=0 for mb in a: - print mb,mb.asarray() - print "a.y=",a.y + s+=_sum_all(numpy.array(mb)) + s+=a[3:6].x[1,1] for mb in ArrayDataSet(data=a.y,minibatch_size=2): - print mb for e in mb: - print e - self.failUnless(True) + s+=sum(e) + #print numpy.array(a) + #print a.y[4:9:2] + s+= _sum_all(a.y[4:9:2]) + #print s + self.failUnless(abs(s-39.0334797)<1e-6) if __name__ == '__main__': unittest.main() diff -r 6f8f338686db -r d1c394486037 dataset.py --- a/dataset.py Mon Mar 24 13:20:15 2008 -0400 +++ b/dataset.py Mon Mar 24 15:56:53 2008 -0400 @@ -107,7 +107,7 @@ must correspond to a slice of columns. If the dataset has fields, each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array. Any dataset can also be converted to a numpy array (losing the notion of fields) - by the asarray(dataset) call. + by the numpy.array(dataset) call. """ def __init__(self,dataset=None,data=None,fields={},minibatch_size=1): @@ -187,7 +187,7 @@ """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" return ArrayDataSet(data=self.data[apply(slice,slice_args)],fields=self.fields) - def asarray(self): + def __array__(self): if not self.fields: return self.data # else, select subsets of columns mapped by the fields