Mercurial > pylearn
diff _test_dataset.py @ 16:813723310d75
commenting
author | bergstrj@iro.umontreal.ca |
---|---|
date | Wed, 26 Mar 2008 18:23:44 -0400 |
parents | be128b9127c8 |
children | 759d17112b23 |
line wrap: on
line diff
--- a/_test_dataset.py Tue Mar 25 13:38:51 2008 -0400 +++ b/_test_dataset.py Wed Mar 26 18:23:44 2008 -0400 @@ -13,7 +13,7 @@ numpy.random.seed(123456) def test0(self): - a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)},minibatch_size=1) + a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)}) s=0 for example in a: s+=_sum_all(example.x) @@ -21,13 +21,12 @@ self.failUnless(abs(s-7.25967597)<1e-6) def test1(self): - a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)},minibatch_size=1) - a.minibatch_size=2 + a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)}) s=0 - for mb in a: + for mb in a.minibatches(2): s+=_sum_all(numpy.array(mb)) s+=a[3:6].x[1,1] - for mb in ArrayDataSet(data=a.y,minibatch_size=2): + for mb in ArrayDataSet(data=a.y).minibatches(2): for e in mb: s+=sum(e) #print numpy.array(a)