Mercurial > pylearn
diff test_dataset.py @ 229:d7250ee86f72
Added speed test for ArraDataSet
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Fri, 16 May 2008 16:40:26 -0400 |
parents | b9950ae5e54b |
children | 4d1bd2513e06 |
line wrap: on
line diff
--- a/test_dataset.py Fri May 16 16:38:07 2008 -0400 +++ b/test_dataset.py Fri May 16 16:40:26 2008 -0400 @@ -2,6 +2,7 @@ from dataset import * from math import * import numpy +from misc import * def have_raised(to_eval, **var): have_thrown = False @@ -463,7 +464,9 @@ ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False) - ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True) + ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), + ['x','y','z'], + minibatch_mode=True) test_all(a2,ds2) test_all(a2,ds3) @@ -485,10 +488,63 @@ def test_ArrayFieldsDataSet(): print "test_ArrayFieldsDataSet" raise NotImplementedError() + + +def test_speed(): + print "test_speed" + import time + a2 = numpy.random.rand(100000,400) + ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested + ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested + ds = ArrayDataSet(a2,{all:slice(0,a2.shape[1],1)}) + #assert ds==a? should this work? + @print_timing + def f_array1(a): + a+1 + @print_timing + def f_array2(a): + for id in range(a.shape[0]): + pass +# a[id]+1 + @print_timing + def f_ds(ds): + for ex in ds: + pass +# ex[0]+1 + @print_timing + def f_ds_mb1(ds,mb_size): + for exs in ds.minibatches(minibatch_size = mb_size): + for ex in exs: + pass +# ex[0]+1 + @print_timing + def f_ds_mb2(ds,mb_size): + for exs in ds.minibatches(minibatch_size = mb_size): + pass +# exs[0]+1 + + f_array1(a2) + f_array2(a2) + + f_ds(ds) + + f_ds_mb1(ds,10) + f_ds_mb1(ds,100) + f_ds_mb1(ds,1000) + f_ds_mb1(ds,10000) + f_ds_mb2(ds,10) + f_ds_mb2(ds,100) + f_ds_mb2(ds,1000) + f_ds_mb2(ds,10000) + + del a2, ds + if __name__=='__main__': test1() test_LookupList() test_ArrayDataSet() test_CachedDataSet() test_ApplyFunctionDataSet() + #test_speed() #test pmat.py +