Mercurial > pylearn
view _test_dataset.py @ 26:672fe4b23032
Fixed dataset errors so that _test_dataset.py works again.
author | bengioy@grenat.iro.umontreal.ca |
---|---|
date | Fri, 11 Apr 2008 11:14:54 -0400 |
parents | b6b36f65664f |
children | 541a273bc89f |
line wrap: on
line source
"""Unit tests for ArrayDataSet and its renamed view, both from the `dataset` module."""

import unittest

import numpy  # imported explicitly instead of relying on `dataset` re-exporting it

# The two star imports are kept in their original relative order so that any
# name shadowing between them is unchanged.
from dataset import *
from math import *


def _sum_all(a):
    """Collapse `a` to a scalar by applying `sum` until it is no longer an ndarray.

    Not called by any test visible in this file; kept for compatibility.
    """
    s = a
    while isinstance(s, numpy.ndarray):
        s = sum(s)
    return s


class T_arraydataset(unittest.TestCase):
    """Tests for ArrayDataSet: construction, iteration, fields and minibatches."""

    def setUp(self):
        # Fixed seed so the random test matrices are reproducible.
        numpy.random.seed(123456)

    def test_ctor_len(self):
        """A dataset built from a bare array keeps it as `data` and has no fields."""
        n = numpy.random.rand(8, 3)
        a = ArrayDataSet(n)
        self.assertTrue(a.data is n)
        self.assertTrue(a.fields is None)
        self.assertEqual(len(a), n.shape[0])
        self.assertEqual(a[0].shape, (n.shape[1],))

    def test_iter(self):
        """Each iterated example exposes its fields as the matching column slices."""
        arr = numpy.random.rand(8, 3)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 3)})
        for i, example in enumerate(a):
            self.assertTrue(numpy.all(example['x'] == arr[i, :2]))
            self.assertTrue(numpy.all(example['y'] == arr[i, 1:3]))

    def test_zip(self):
        """zip("x") yields, row by row, the columns selected by field "x"."""
        arr = numpy.random.rand(8, 3)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 3)})
        for i, x in enumerate(a.zip("x")):
            self.assertTrue(numpy.all(x == arr[i, :2]))

    def test_minibatch_basic(self):
        """Minibatches of size 2 over all fields match consecutive 2-row slices."""
        arr = numpy.random.rand(10, 4)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        for i, mb in enumerate(a.minibatches(minibatch_size=2)):  # all fields
            self.assertTrue(numpy.all(mb['x'] == arr[i * 2:i * 2 + 2, 0:2]))
            self.assertTrue(numpy.all(mb['y'] == arr[i * 2:i * 2 + 2, 1:4]))

    def test_getattr(self):
        """Attribute access by field name returns that field's columns for all rows."""
        arr = numpy.random.rand(10, 4)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        a_y = a.y
        self.assertTrue(numpy.all(a_y == arr[:, 1:4]))

    def test_asarray(self):
        """numpy.asarray(dataset) has one column per field column.

        With disjoint fields the result equals the underlying data; with
        overlapping fields (x = columns 0-1, y = columns 1-3) the expected
        width is 2 + 3, i.e. the shared column is counted once per field.
        """
        arr = numpy.random.rand(3, 4)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(2, 4)})
        a_arr = numpy.asarray(a)
        self.assertEqual(a_arr.shape[1], 2 + 2)
        self.assertEqual(numpy.sum(numpy.square(a_arr - a.data)), 0)

        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        a_arr = numpy.asarray(a)
        self.assertEqual(a_arr.shape[1], 2 + 3)

    def test_minibatch_wraparound_even(self):
        """8 minibatches of 2 rows over a 10-row dataset continue past the end.

        The expected rows are read from `matcat(arr, arr)`, which presumably
        stacks the array on top of itself (TODO confirm in `dataset`).
        """
        arr = numpy.random.rand(10, 4)
        arr2 = ArrayDataSet.Iterator.matcat(arr, arr)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)):
            self.assertTrue(numpy.all(x == arr2[i * 2:i * 2 + 2, 0:2]))

    def test_minibatch_wraparound_odd(self):
        """6 minibatches of 3 rows over a 10-row dataset continue past the end.

        The batch size does not divide the dataset length, so one batch is
        expected to straddle the end of the data.
        """
        arr = numpy.random.rand(10, 4)
        arr2 = ArrayDataSet.Iterator.matcat(arr, arr)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)):
            self.assertTrue(numpy.all(x == arr2[i * 3:i * 3 + 3, 0:2]))


class T_renamingdataset(unittest.TestCase):
    """Tests for the dataset returned by `ArrayDataSet.rename`."""

    def setUp(self):
        # Fixed seed so the random test matrices are reproducible.
        numpy.random.seed(123456)

    def test_hasfield(self):
        """After rename({'xx': 'x', 'zz': 'z'}) only the new names are reported.

        The old name 'x' and the un-renamed field 'y' must both be absent.
        """
        n = numpy.random.rand(3, 8)
        a = ArrayDataSet(data=n, fields={"x": slice(2), "y": slice(1, 4), "z": slice(4, 6)})
        b = a.rename({'xx': 'x', 'zz': 'z'})
        # Separate assertions so a failure names the condition that broke.
        self.assertTrue(b.hasFields('xx', 'zz'))
        self.assertFalse(b.hasFields('x'))
        self.assertFalse(b.hasFields('y'))


if __name__ == '__main__':
    unittest.main()