Mercurial > pylearn
comparison test_dataset.py @ 229:d7250ee86f72
Added speed test for ArrayDataSet
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Fri, 16 May 2008 16:40:26 -0400 |
parents | b9950ae5e54b |
children | 4d1bd2513e06 |
comparison
equal
deleted
inserted
replaced
228:6f55e301c687 | 229:d7250ee86f72 |
---|---|
1 #!/bin/env python | 1 #!/bin/env python |
2 from dataset import * | 2 from dataset import * |
3 from math import * | 3 from math import * |
4 import numpy | 4 import numpy |
5 from misc import * | |
5 | 6 |
6 def have_raised(to_eval, **var): | 7 def have_raised(to_eval, **var): |
7 have_thrown = False | 8 have_thrown = False |
8 try: | 9 try: |
9 eval(to_eval) | 10 eval(to_eval) |
461 a = numpy.random.rand(10,4) | 462 a = numpy.random.rand(10,4) |
462 a2 = a+1 | 463 a2 = a+1 |
463 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested | 464 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested |
464 | 465 |
465 ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False) | 466 ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False) |
466 ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True) | 467 ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), |
468 ['x','y','z'], | |
469 minibatch_mode=True) | |
467 | 470 |
468 test_all(a2,ds2) | 471 test_all(a2,ds2) |
469 test_all(a2,ds3) | 472 test_all(a2,ds3) |
470 | 473 |
471 del a,ds1,ds2,ds3 | 474 del a,ds1,ds2,ds3 |
483 print "test_VStackedDataSet" | 486 print "test_VStackedDataSet" |
484 raise NotImplementedError() | 487 raise NotImplementedError() |
485 def test_ArrayFieldsDataSet(): | 488 def test_ArrayFieldsDataSet(): |
486 print "test_ArrayFieldsDataSet" | 489 print "test_ArrayFieldsDataSet" |
487 raise NotImplementedError() | 490 raise NotImplementedError() |
491 | |
492 | |
493 def test_speed(): | |
494 print "test_speed" | |
495 import time | |
496 a2 = numpy.random.rand(100000,400) | |
497 ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested | |
498 ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested | |
499 ds = ArrayDataSet(a2,{all:slice(0,a2.shape[1],1)}) | |
500 #assert ds==a? should this work? | |
501 @print_timing | |
502 def f_array1(a): | |
503 a+1 | |
504 @print_timing | |
505 def f_array2(a): | |
506 for id in range(a.shape[0]): | |
507 pass | |
508 # a[id]+1 | |
509 @print_timing | |
510 def f_ds(ds): | |
511 for ex in ds: | |
512 pass | |
513 # ex[0]+1 | |
514 @print_timing | |
515 def f_ds_mb1(ds,mb_size): | |
516 for exs in ds.minibatches(minibatch_size = mb_size): | |
517 for ex in exs: | |
518 pass | |
519 # ex[0]+1 | |
520 @print_timing | |
521 def f_ds_mb2(ds,mb_size): | |
522 for exs in ds.minibatches(minibatch_size = mb_size): | |
523 pass | |
524 # exs[0]+1 | |
525 | |
526 f_array1(a2) | |
527 f_array2(a2) | |
528 | |
529 f_ds(ds) | |
530 | |
531 f_ds_mb1(ds,10) | |
532 f_ds_mb1(ds,100) | |
533 f_ds_mb1(ds,1000) | |
534 f_ds_mb1(ds,10000) | |
535 f_ds_mb2(ds,10) | |
536 f_ds_mb2(ds,100) | |
537 f_ds_mb2(ds,1000) | |
538 f_ds_mb2(ds,10000) | |
539 | |
540 del a2, ds | |
541 | |
488 if __name__=='__main__': | 542 if __name__=='__main__': |
489 test1() | 543 test1() |
490 test_LookupList() | 544 test_LookupList() |
491 test_ArrayDataSet() | 545 test_ArrayDataSet() |
492 test_CachedDataSet() | 546 test_CachedDataSet() |
493 test_ApplyFunctionDataSet() | 547 test_ApplyFunctionDataSet() |
548 #test_speed() | |
494 #test pmat.py | 549 #test pmat.py |
550 |