diff _test_dataset.py @ 21:fdf0abc490f7

Adapted _test_dataset.py to changes in LookupList
author bengioy@bengiomac.local
date Mon, 07 Apr 2008 19:32:52 -0400
parents 57f4015e2e09
children b6b36f65664f
line wrap: on
line diff
--- a/_test_dataset.py	Mon Apr 07 09:48:39 2008 -0400
+++ b/_test_dataset.py	Mon Apr 07 19:32:52 2008 -0400
@@ -26,8 +26,8 @@
         arr = numpy.random.rand(8,3)
         a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
         for i, example in enumerate(a):
-            self.failUnless(numpy.all( example.x == arr[i,:2]))
-            self.failUnless(numpy.all( example.y == arr[i,1:3]))
+            self.failUnless(numpy.all( example['x'] == arr[i,:2]))
+            self.failUnless(numpy.all( example['y'] == arr[i,1:3]))
 
     def test_zip(self):
         arr = numpy.random.rand(8,3)
@@ -39,8 +39,8 @@
         arr = numpy.random.rand(10,4)
         a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
         for i, mb in enumerate(a.minibatches(minibatch_size=2)): #all fields
-            self.failUnless(numpy.all( mb.x == arr[i*2:i*2+2,0:2]))
-            self.failUnless(numpy.all( mb.y == arr[i*2:i*2+2,1:4]))
+            self.failUnless(numpy.all( mb['x'] == arr[i*2:i*2+2,0:2]))
+            self.failUnless(numpy.all( mb['y'] == arr[i*2:i*2+2,1:4]))
 
     def test_getattr(self):
         arr = numpy.random.rand(10,4)
@@ -53,6 +53,7 @@
         a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
         a_arr = numpy.asarray(a)
         self.failUnless(a_arr.shape[1] == 2 + 3)
+        self.failUnless(a_arr == arr)
 
     def test_minibatch_wraparound_even(self):
         arr = numpy.random.rand(10,4)