Mercurial > pylearn
diff _test_dataset.py @ 17:759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
author | bergstrj@iro.umontreal.ca |
---|---|
date | Wed, 26 Mar 2008 21:05:14 -0400 |
parents | be128b9127c8 |
children | 57f4015e2e09 |
line wrap: on
line diff
--- a/_test_dataset.py Wed Mar 26 18:23:44 2008 -0400 +++ b/_test_dataset.py Wed Mar 26 21:05:14 2008 -0400 @@ -12,28 +12,67 @@ def setUp(self): numpy.random.seed(123456) - def test0(self): - a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)}) - s=0 - for example in a: - s+=_sum_all(example.x) - #print s - self.failUnless(abs(s-7.25967597)<1e-6) + + def test_ctor_len(self): + n = numpy.random.rand(8,3) + a=ArrayDataSet(n) + self.failUnless(a.data is n) + self.failUnless(a.fields is None) + + self.failUnless(len(a) == n.shape[0]) + self.failUnless(a[0].shape == (n.shape[1],)) + + def test_iter(self): + arr = numpy.random.rand(8,3) + a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)}) + for i, example in enumerate(a): + self.failUnless(numpy.all( example.x == arr[i,:2])) + self.failUnless(numpy.all( example.y == arr[i,1:3])) + + def test_zip(self): + arr = numpy.random.rand(8,3) + a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)}) + for i, x in enumerate(a.zip("x")): + self.failUnless(numpy.all( x == arr[i,:2])) + + def test_minibatch_basic(self): + arr = numpy.random.rand(10,4) + a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) + for i, mb in enumerate(a.minibatches(minibatch_size=2)): #all fields + self.failUnless(numpy.all( mb.x == arr[i*2:i*2+2,0:2])) + self.failUnless(numpy.all( mb.y == arr[i*2:i*2+2,1:4])) - def test1(self): - a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)}) - s=0 - for mb in a.minibatches(2): - s+=_sum_all(numpy.array(mb)) - s+=a[3:6].x[1,1] - for mb in ArrayDataSet(data=a.y).minibatches(2): - for e in mb: - s+=sum(e) - #print numpy.array(a) - #print a.y[4:9:2] - s+= _sum_all(a.y[4:9:2]) - #print s - self.failUnless(abs(s-39.0334797)<1e-6) + def test_getattr(self): + arr = numpy.random.rand(10,4) + a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) + a_y = a.y + self.failUnless(numpy.all( a_y == arr[:,1:4])) + + def test_asarray(self): + arr = numpy.random.rand(3,4) + a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) + a_arr = numpy.asarray(a) + self.failUnless(a_arr.shape[1] == 2 + 3) + + def test_minibatch_wraparound_even(self): + arr = numpy.random.rand(10,4) + arr2 = ArrayDataSet.Iterator.matcat(arr,arr) + + a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) + + #print arr + for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)): + #print 'x' , x + self.failUnless(numpy.all( x == arr2[i*2:i*2+2,0:2])) + + def test_minibatch_wraparound_odd(self): + arr = numpy.random.rand(10,4) + arr2 = ArrayDataSet.Iterator.matcat(arr,arr) + + a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) + + for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)): + self.failUnless(numpy.all( x == arr2[i*3:i*3+3,0:2])) if __name__ == '__main__': unittest.main()