diff _test_dataset.py @ 17:759d17112b23

more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
author bergstrj@iro.umontreal.ca
date Wed, 26 Mar 2008 21:05:14 -0400
parents be128b9127c8
children 57f4015e2e09
line wrap: on
line diff
--- a/_test_dataset.py	Wed Mar 26 18:23:44 2008 -0400
+++ b/_test_dataset.py	Wed Mar 26 21:05:14 2008 -0400
@@ -12,28 +12,67 @@
     def setUp(self):
         numpy.random.seed(123456)
 
-    def test0(self):
-        a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)})
-        s=0
-        for example in a:
-            s+=_sum_all(example.x)
-        #print s
-        self.failUnless(abs(s-7.25967597)<1e-6)
+
+    def test_ctor_len(self):
+        n = numpy.random.rand(8,3)
+        a=ArrayDataSet(n)
+        self.failUnless(a.data is n)
+        self.failUnless(a.fields is None)
+
+        self.failUnless(len(a) == n.shape[0])
+        self.failUnless(a[0].shape == (n.shape[1],))
+
+    def test_iter(self):
+        arr = numpy.random.rand(8,3)
+        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
+        for i, example in enumerate(a):
+            self.failUnless(numpy.all( example.x == arr[i,:2]))
+            self.failUnless(numpy.all( example.y == arr[i,1:3]))
+
+    def test_zip(self):
+        arr = numpy.random.rand(8,3)
+        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
+        for i, x in enumerate(a.zip("x")):
+            self.failUnless(numpy.all( x == arr[i,:2]))
+
+    def test_minibatch_basic(self):
+        arr = numpy.random.rand(10,4)
+        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
+        for i, mb in enumerate(a.minibatches(minibatch_size=2)): #all fields
+            self.failUnless(numpy.all( mb.x == arr[i*2:i*2+2,0:2]))
+            self.failUnless(numpy.all( mb.y == arr[i*2:i*2+2,1:4]))
 
-    def test1(self):
-        a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)})
-        s=0
-        for mb in a.minibatches(2):
-            s+=_sum_all(numpy.array(mb))
-        s+=a[3:6].x[1,1]
-        for mb in ArrayDataSet(data=a.y).minibatches(2):
-            for e in mb:
-                s+=sum(e)
-        #print numpy.array(a)
-        #print a.y[4:9:2]
-        s+= _sum_all(a.y[4:9:2])
-        #print s
-        self.failUnless(abs(s-39.0334797)<1e-6)
+    def test_getattr(self):
+        arr = numpy.random.rand(10,4)
+        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
+        a_y = a.y
+        self.failUnless(numpy.all( a_y == arr[:,1:4]))
+
+    def test_asarray(self):
+        arr = numpy.random.rand(3,4)
+        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
+        a_arr = numpy.asarray(a)
+        self.failUnless(a_arr.shape[1] == 2 + 3)
+
+    def test_minibatch_wraparound_even(self):
+        arr = numpy.random.rand(10,4)
+        arr2 = ArrayDataSet.Iterator.matcat(arr,arr)
+
+        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
+
+        #print arr
+        for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)):
+            #print 'x' , x
+            self.failUnless(numpy.all( x == arr2[i*2:i*2+2,0:2]))
+
+    def test_minibatch_wraparound_odd(self):
+        arr = numpy.random.rand(10,4)
+        arr2 = ArrayDataSet.Iterator.matcat(arr,arr)
+
+        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
+
+        for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)):
+            self.failUnless(numpy.all( x == arr2[i*3:i*3+3,0:2]))
         
 if __name__ == '__main__':
     unittest.main()