diff _test_dataset.py @ 11:be128b9127c8

Debugged (to the extent of my tests) the new version of dataset
author bengioy@esprit.iro.umontreal.ca
date Wed, 26 Mar 2008 15:01:30 -0400
parents d1c394486037
children 759d17112b23
line wrap: on
line diff
--- a/_test_dataset.py	Tue Mar 25 11:39:02 2008 -0400
+++ b/_test_dataset.py	Wed Mar 26 15:01:30 2008 -0400
@@ -13,7 +13,7 @@
         numpy.random.seed(123456)
 
     def test0(self):
-        a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)},minibatch_size=1)
+        a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)})
         s=0
         for example in a:
             s+=_sum_all(example.x)
@@ -21,13 +21,12 @@
         self.failUnless(abs(s-7.25967597)<1e-6)
 
     def test1(self):
-        a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)},minibatch_size=1)
-        a.minibatch_size=2
+        a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)})
         s=0
-        for mb in a:
+        for mb in a.minibatches(2):
             s+=_sum_all(numpy.array(mb))
         s+=a[3:6].x[1,1]
-        for mb in ArrayDataSet(data=a.y,minibatch_size=2):
+        for mb in ArrayDataSet(data=a.y).minibatches(2):
             for e in mb:
                 s+=sum(e)
         #print numpy.array(a)