pylearn: _test_dataset.py @ 221:58e17421c69c
tester on iterator consistency now triggers a bug in dataset, linked to the combination of minibatch and slicing
| author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
|---|---|
| date | Fri, 23 May 2008 14:07:53 -0400 |
| parents | 1f527fe65e22 |
| children | 174374d59405 |
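The combination referred to in the commit message is exercised by test3_fields_iterator_consistency in the source below: the dataset is sliced, iterated example by example, and then iterated again through its minibatches interface. The condensed sketch here mirrors that test; the helper name check_minibatch_over_slice is hypothetical, and ds stands for any DataSet instance as used in the test.

def check_minibatch_over_slice(ds):
    # hypothetical helper, not in the original file: condenses the pattern
    # from test3_fields_iterator_consistency below
    maxsize = min(len(ds) - 1, 100)
    # plain slicing followed by per-example iteration
    for example in ds[:maxsize]:
        assert len(example) == len(ds.fieldNames())
    # minibatch iteration restricted to a slice and to the first two fields;
    # per the commit message, this combination currently triggers a bug in dataset.py
    ds2 = ds.minibatches[:maxsize]([ds.fieldNames()[0], ds.fieldNames()[1]],
                                   minibatch_size=2)
    for mb in ds2:
        assert len(mb) == 2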
line source
from dataset import *
from math import *
import unittest
import sys
import numpy
import numpy as N

def _sum_all(a):
    # recursively sum over every axis until a scalar remains
    s=a
    while isinstance(s,numpy.ndarray):
        s=sum(s)
    return s

class T_arraydataset(unittest.TestCase):
    def setUp(self):
        numpy.random.seed(123456)

    def test_ctor_len(self):
        n = numpy.random.rand(8,3)
        a=ArrayDataSet(n)
        self.failUnless(a.data is n)
        self.failUnless(a.fields is None)
        self.failUnless(len(a) == n.shape[0])
        self.failUnless(a[0].shape == (n.shape[1],))

    def test_iter(self):
        arr = numpy.random.rand(8,3)
        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
        for i, example in enumerate(a):
            self.failUnless(numpy.all(example['x'] == arr[i,:2]))
            self.failUnless(numpy.all(example['y'] == arr[i,1:3]))

    def test_zip(self):
        arr = numpy.random.rand(8,3)
        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
        for i, x in enumerate(a.zip("x")):
            self.failUnless(numpy.all(x == arr[i,:2]))

    def test_minibatch_basic(self):
        arr = numpy.random.rand(10,4)
        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
        for i, mb in enumerate(a.minibatches(minibatch_size=2)): # all fields
            self.failUnless(numpy.all(mb['x'] == arr[i*2:i*2+2,0:2]))
            self.failUnless(numpy.all(mb['y'] == arr[i*2:i*2+2,1:4]))

    def test_getattr(self):
        arr = numpy.random.rand(10,4)
        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
        a_y = a.y
        self.failUnless(numpy.all(a_y == arr[:,1:4]))

    def test_minibatch_wraparound_even(self):
        arr = numpy.random.rand(10,4)
        arr2 = ArrayDataSet.Iterator.matcat(arr,arr)

        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
        #print arr
        for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)):
            #print 'x' , x
            self.failUnless(numpy.all(x == arr2[i*2:i*2+2,0:2]))

    def test_minibatch_wraparound_odd(self):
        arr = numpy.random.rand(10,4)
        arr2 = ArrayDataSet.Iterator.matcat(arr,arr)

        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
        for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)):
            self.failUnless(numpy.all(x == arr2[i*3:i*3+3,0:2]))


class T_renamingdataset(unittest.TestCase):
    def setUp(self):
        numpy.random.seed(123456)

    def test_hasfield(self):
        n = numpy.random.rand(3,8)
        a=ArrayDataSet(data=n,fields={"x":slice(2),"y":slice(1,4),"z":slice(4,6)})
        b=a.rename({'xx':'x','zz':'z'})
        self.failUnless(b.hasFields('xx','zz') and not b.hasFields('x') and not b.hasFields('y'))


class T_applyfunctiondataset(unittest.TestCase):
    def setUp(self):
        numpy.random.seed(123456)

    def test_function(self):
        n = numpy.random.rand(3,8)
        a=ArrayDataSet(data=n,fields={"x":slice(2),"y":slice(1,4),"z":slice(4,6)})
        # the function must return one value per output field ('x+y' and 'x+1')
        b=a.apply_function(lambda x,y: (x+y,x+1), ['x','y'], ['x+y','x+1'], False,False,False)
        print b.fieldNames()
        print b('x+y')


# to be used with any new dataset
class T_dataset_tester(object):
    """
    This class' goal is to test any new dataset that is created.
    Tests are (will be!) designed to check the normal behaviours
    of a dataset, as defined in dataset.py.
    """

    def __init__(self,ds,runall=True) :
        """if interested in only a subset of the tests, init with runall=False"""
        self.ds = ds

        if runall :
            self.test1_basicstats(ds)
            self.test2_slicing(ds)
            self.test3_fields_iterator_consistency(ds)

    def test1_basicstats(self,ds) :
        """print basic stats on a dataset, like its length"""

        print 'len(ds) = ',len(ds)
        print 'num fields = ', len(ds.fieldNames())
        print 'types of field: ',
        for k in ds.fieldNames() :
            print type(ds[0](k)[0]),
        print ''

    def test2_slicing(self,ds) :
        """test if slicing works properly"""
        print 'testing slicing...',
        sys.stdout.flush()

        middle = len(ds) / 2
        tenpercent = int(len(ds) * .1)
        # set1 and set2 overlap on 2*tenpercent examples; the overlapping
        # examples must be identical in both slices
        set1 = ds[:middle+tenpercent]
        set2 = ds[middle-tenpercent:]
        for k in range(tenpercent + tenpercent -1):
            for k2 in ds.fieldNames() :
                if type(set1[middle-tenpercent+k](k2)[0]) == N.ndarray :
                    for k3 in range(len(set1[middle-tenpercent+k](k2)[0])) :
                        assert set1[middle-tenpercent+k](k2)[0][k3] == set2[k](k2)[0][k3]
                else :
                    assert set1[middle-tenpercent+k](k2)[0] == set2[k](k2)[0]

        # slicing with a step of 2 must also agree with the plain slices
        assert tenpercent > 1
        set3 = ds[middle-tenpercent:middle+tenpercent:2]
        for k2 in ds.fieldNames() :
            if type(set2[2](k2)[0]) == N.ndarray :
                for k3 in range(len(set2[2](k2)[0])) :
                    assert set2[2](k2)[0][k3] == set3[1](k2)[0][k3]
            else :
                assert set2[2](k2)[0] == set3[1](k2)[0]

        print 'done'

    def test3_fields_iterator_consistency(self,ds) :
        """check if the number of entries per iterated example corresponds to the number of fields"""

        print 'testing fields/iterator consistency...',
        sys.stdout.flush()

        # basic test: every example yielded by the iterator exposes every field
        maxsize = min(len(ds)-1,100)
        for iter in ds[:maxsize] :
            assert len(iter) == len(ds.fieldNames())
        if len(ds.fieldNames()) == 1 :
            print 'done'
            return

        # with minibatches iterator: minibatches over a slice of the dataset,
        # restricted to the first two fields
        ds2 = ds.minibatches[:maxsize]([ds.fieldNames()[0],ds.fieldNames()[1]],minibatch_size=2)
        for iter in ds2 :
            assert len(iter) == 2

        print 'done'


###################################################################
# main
if __name__ == '__main__':
    unittest.main()
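As a usage note, the following hypothetical sketch shows one way T_dataset_tester could be driven on a concrete dataset, following its docstring. It assumes this module can be imported as _test_dataset and that ArrayDataSet is importable from dataset, as in the unit tests above; the array shape and field slices are illustrative choices, not part of this revision.

import numpy
from dataset import ArrayDataSet
from _test_dataset import T_dataset_tester

# build a small random dataset with two (overlapping) fields, in the same
# style as the ArrayDataSet unit tests above
numpy.random.seed(123456)
arr = numpy.random.rand(20, 4)
ds = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})

# with runall=True (the default) the constructor immediately runs
# test1_basicstats, test2_slicing and test3_fields_iterator_consistency
T_dataset_tester(ds)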