comparison test_dataset.py @ 229:d7250ee86f72

Added speed test for ArrayDataSet
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Fri, 16 May 2008 16:40:26 -0400
parents b9950ae5e54b
children 4d1bd2513e06
comparison
equal deleted inserted replaced
228:6f55e301c687 229:d7250ee86f72
1 #!/bin/env python 1 #!/bin/env python
2 from dataset import * 2 from dataset import *
3 from math import * 3 from math import *
4 import numpy 4 import numpy
5 from misc import *
5 6
6 def have_raised(to_eval, **var): 7 def have_raised(to_eval, **var):
7 have_thrown = False 8 have_thrown = False
8 try: 9 try:
9 eval(to_eval) 10 eval(to_eval)
461 a = numpy.random.rand(10,4) 462 a = numpy.random.rand(10,4)
462 a2 = a+1 463 a2 = a+1
463 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested 464 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
464 465
465 ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False) 466 ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False)
466 ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True) 467 ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1),
468 ['x','y','z'],
469 minibatch_mode=True)
467 470
468 test_all(a2,ds2) 471 test_all(a2,ds2)
469 test_all(a2,ds3) 472 test_all(a2,ds3)
470 473
471 del a,ds1,ds2,ds3 474 del a,ds1,ds2,ds3
483 print "test_VStackedDataSet" 486 print "test_VStackedDataSet"
484 raise NotImplementedError() 487 raise NotImplementedError()
485 def test_ArrayFieldsDataSet(): 488 def test_ArrayFieldsDataSet():
486 print "test_ArrayFieldsDataSet" 489 print "test_ArrayFieldsDataSet"
487 raise NotImplementedError() 490 raise NotImplementedError()
491
492
def test_speed(n_examples=100000, n_inputs=400,
               minibatch_sizes=(10, 100, 1000, 10000)):
    """Time iteration over an ArrayDataSet against plain numpy loops.

    Not run by default (the call in the ``__main__`` section is commented
    out) because the default data array holds n_examples * n_inputs
    float64 values -- roughly 320 MB with the default arguments.

    Every helper below is wrapped with ``print_timing`` (from ``misc``),
    which presumably reports the duration of each call -- TODO confirm.

    :param n_examples: number of rows (examples) in the random data array.
    :param n_inputs: number of columns; must be >= 4 because the field
        specs below reference columns 0..3.
    :param minibatch_sizes: minibatch sizes passed to ``ds.minibatches``;
        each size is timed with both minibatch helpers.
    """
    print("test_speed")
    data = numpy.random.rand(n_examples, n_inputs)

    # Smoke-check that the constructor accepts both a dict and a
    # LookupList field specification (tuple column specs are still not
    # exercised).  These datasets are not timed; only the whole-array
    # dataset built below is.
    ArrayDataSet(data, {'x': slice(3), 'y': 3, 'z': [0, 2]})
    ArrayDataSet(data, LookupList(['x', 'y', 'z'], [slice(3), 3, [0, 2]]))
    # NOTE(review): the field key is the builtin ``all``, not the string
    # 'all'; kept as in the original -- confirm this is intended.
    ds = ArrayDataSet(data, {all: slice(0, data.shape[1], 1)})
    # TODO: should ``ds == data`` hold for a whole-array dataset?

    # Loop bodies are deliberately empty so that only iteration overhead
    # is measured; replace ``pass`` with e.g. ``ex[0] + 1`` to also time
    # element access.
    @print_timing
    def f_array1(a):
        # Baseline: one vectorized numpy operation over the whole array.
        a + 1

    @print_timing
    def f_array2(a):
        # Baseline: a bare Python loop over the row indices.
        for _ in range(a.shape[0]):
            pass

    @print_timing
    def f_ds(ds):
        # Example-by-example iteration over the dataset.
        for _ in ds:
            pass

    @print_timing
    def f_ds_mb1(ds, mb_size):
        # Minibatch iteration, then iterate within each minibatch.
        for exs in ds.minibatches(minibatch_size=mb_size):
            for _ in exs:
                pass

    @print_timing
    def f_ds_mb2(ds, mb_size):
        # Minibatch iteration only.
        for _ in ds.minibatches(minibatch_size=mb_size):
            pass

    f_array1(data)
    f_array2(data)

    f_ds(ds)

    for mb_size in minibatch_sizes:
        f_ds_mb1(ds, mb_size)
    for mb_size in minibatch_sizes:
        f_ds_mb2(ds, mb_size)

    del data, ds
541
488 if __name__=='__main__': 542 if __name__=='__main__':
489 test1() 543 test1()
490 test_LookupList() 544 test_LookupList()
491 test_ArrayDataSet() 545 test_ArrayDataSet()
492 test_CachedDataSet() 546 test_CachedDataSet()
493 test_ApplyFunctionDataSet() 547 test_ApplyFunctionDataSet()
548 #test_speed()
494 #test pmat.py 549 #test pmat.py
550