Mercurial > pylearn
view _test_dataset.py @ 99:a8da709eb6a9
in ArrayDataSet.__init__ if a columns is an index, we change it to be a list that containt only this index. This way, we remove the special case where the columns is an index for all subsequent call.
This was possing trouble with numpy.vstack() called by MinibatchWrapAroundIterator.next
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 13:57:36 -0400 |
parents | 46c5c90019c2 |
children | 5b3afda2f1ad |
line wrap: on
line source
from dataset import * from math import * import unittest def _sum_all(a): s=a while isinstance(s,numpy.ndarray): s=sum(s) return s class T_arraydataset(unittest.TestCase): def setUp(self): numpy.random.seed(123456) def test_ctor_len(self): n = numpy.random.rand(8,3) a=ArrayDataSet(n) self.failUnless(a.data is n) self.failUnless(a.fields is None) self.failUnless(len(a) == n.shape[0]) self.failUnless(a[0].shape == (n.shape[1],)) def test_iter(self): arr = numpy.random.rand(8,3) a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)}) for i, example in enumerate(a): self.failUnless(numpy.all( example['x'] == arr[i,:2])) self.failUnless(numpy.all( example['y'] == arr[i,1:3])) def test_zip(self): arr = numpy.random.rand(8,3) a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)}) for i, x in enumerate(a.zip("x")): self.failUnless(numpy.all( x == arr[i,:2])) def test_minibatch_basic(self): arr = numpy.random.rand(10,4) a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) for i, mb in enumerate(a.minibatches(minibatch_size=2)): #all fields self.failUnless(numpy.all( mb['x'] == arr[i*2:i*2+2,0:2])) self.failUnless(numpy.all( mb['y'] == arr[i*2:i*2+2,1:4])) def test_getattr(self): arr = numpy.random.rand(10,4) a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) a_y = a.y self.failUnless(numpy.all( a_y == arr[:,1:4])) def test_minibatch_wraparound_even(self): arr = numpy.random.rand(10,4) arr2 = ArrayDataSet.Iterator.matcat(arr,arr) a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) #print arr for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)): #print 'x' , x self.failUnless(numpy.all( x == arr2[i*2:i*2+2,0:2])) def test_minibatch_wraparound_odd(self): arr = numpy.random.rand(10,4) arr2 = ArrayDataSet.Iterator.matcat(arr,arr) a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)): self.failUnless(numpy.all( x == arr2[i*3:i*3+3,0:2])) class T_renamingdataset(unittest.TestCase): def setUp(self): numpy.random.seed(123456) def test_hasfield(self): n = numpy.random.rand(3,8) a=ArrayDataSet(data=n,fields={"x":slice(2),"y":slice(1,4),"z":slice(4,6)}) b=a.rename({'xx':'x','zz':'z'}) self.failUnless(b.hasFields('xx','zz') and not b.hasFields('x') and not b.hasFields('y')) class T_applyfunctiondataset(unittest.TestCase): def setUp(self): numpy.random.seed(123456) def test_function(self): n = numpy.random.rand(3,8) a=ArrayDataSet(data=n,fields={"x":slice(2),"y":slice(1,4),"z":slice(4,6)}) b=a.apply_function(lambda x,y: x+y,x+1, ['x','y'], ['x+y','x+1'], False,False,False) print b.fieldNames() print b('x+y') if __name__ == '__main__': unittest.main()