changeset 8:d1c394486037

Replaced the asarray() method with an __array__ method, which is called automatically when the dataset is cast to an array or numpy array, or when numpy.asarray(dataset) is called.
author bengioy@bengiomac.local
date Mon, 24 Mar 2008 15:56:53 -0400
parents 6f8f338686db
children de616c423dbd
files _test_dataset.py dataset.py
diffstat 2 files changed, 12 insertions(+), 10 deletions(-) [+]
line wrap: on
line diff
--- a/_test_dataset.py	Mon Mar 24 13:20:15 2008 -0400
+++ b/_test_dataset.py	Mon Mar 24 15:56:53 2008 -0400
@@ -16,23 +16,25 @@
         a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)},minibatch_size=1)
         s=0
         for example in a:
-            print len(example), example.x
             s+=_sum_all(example.x)
-        print s
+        #print s
         self.failUnless(abs(s-7.25967597)<1e-6)
 
     def test1(self):
         a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)},minibatch_size=1)
         a.minibatch_size=2
-        print a.asarray()
+        s=0
         for mb in a:
-            print mb,mb.asarray()
-        print "a.y=",a.y
+            s+=_sum_all(numpy.array(mb))
+        s+=a[3:6].x[1,1]
         for mb in ArrayDataSet(data=a.y,minibatch_size=2):
-            print mb
             for e in mb:
-                print e
-        self.failUnless(True)
+                s+=sum(e)
+        #print numpy.array(a)
+        #print a.y[4:9:2]
+        s+= _sum_all(a.y[4:9:2])
+        #print s
+        self.failUnless(abs(s-39.0334797)<1e-6)
         
 if __name__ == '__main__':
     unittest.main()
--- a/dataset.py	Mon Mar 24 13:20:15 2008 -0400
+++ b/dataset.py	Mon Mar 24 15:56:53 2008 -0400
@@ -107,7 +107,7 @@
     must correspond to a slice of columns. If the dataset has fields,
     each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array.
     Any dataset can also be converted to a numpy array (losing the notion of fields)
-    by the asarray(dataset) call.
+    by the numpy.array(dataset) call.
     """
 
     def __init__(self,dataset=None,data=None,fields={},minibatch_size=1):
@@ -187,7 +187,7 @@
         """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
         return ArrayDataSet(data=self.data[apply(slice,slice_args)],fields=self.fields)
 
-    def asarray(self):
+    def __array__(self):
         if not self.fields:
             return self.data
         # else, select subsets of columns mapped by the fields