Mercurial > pylearn
comparison test_dataset.py @ 231:38beb81f4e8b
Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 27 May 2008 13:46:03 -0400 |
parents | df3fae88ab46 4d1bd2513e06 |
children | 6aff510792dd |
comparison legend: equal, deleted, inserted, replaced
227:17c5d080964b | 231:38beb81f4e8b |
---|---|
1 #!/bin/env python | 1 #!/bin/env python |
2 from dataset import * | 2 from dataset import * |
3 from math import * | 3 from math import * |
4 import numpy | 4 import numpy |
5 from misc import * | |
5 | 6 |
6 def have_raised(to_eval, **var): | 7 def have_raised(to_eval, **var): |
7 have_thrown = False | 8 have_thrown = False |
8 try: | 9 try: |
9 eval(to_eval) | 10 eval(to_eval) |
461 a = numpy.random.rand(10,4) | 462 a = numpy.random.rand(10,4) |
462 a2 = a+1 | 463 a2 = a+1 |
463 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested | 464 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested |
464 | 465 |
465 ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False) | 466 ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False) |
466 ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True) | 467 ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), |
468 ['x','y','z'], | |
469 minibatch_mode=True) | |
467 | 470 |
468 test_all(a2,ds2) | 471 test_all(a2,ds2) |
469 test_all(a2,ds3) | 472 test_all(a2,ds3) |
470 | 473 |
471 del a,ds1,ds2,ds3 | 474 del a,ds1,ds2,ds3 |
483 print "test_VStackedDataSet" | 486 print "test_VStackedDataSet" |
484 raise NotImplementedError() | 487 raise NotImplementedError() |
485 def test_ArrayFieldsDataSet(): | 488 def test_ArrayFieldsDataSet(): |
486 print "test_ArrayFieldsDataSet" | 489 print "test_ArrayFieldsDataSet" |
487 raise NotImplementedError() | 490 raise NotImplementedError() |
491 | |
492 | |
493 def test_speed(): | |
494 print "test_speed" | |
495 import time | |
496 a2 = numpy.random.rand(100000,400) | |
497 ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested | |
498 ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested | |
499 ds = ArrayDataSet(a2,{'all':slice(0,a2.shape[1],1)}) | |
500 #assert ds==a? should this work? | |
501 mat = numpy.random.rand(400,100) | |
502 @print_timing | |
503 def f_array1(a): | |
504 a+1 | |
505 @print_timing | |
506 def f_array2(a): | |
507 for id in range(a.shape[0]): | |
508 # pass | |
509 a[id]+1 | |
510 # a[id]*mat | |
511 @print_timing | |
512 def f_ds(ds): | |
513 for ex in ds: | |
514 # pass | |
515 ex[0]+1 | |
516 # a[id]*mat | |
517 @print_timing | |
518 def f_ds_mb1(ds,mb_size): | |
519 for exs in ds.minibatches(minibatch_size = mb_size): | |
520 for ex in exs: | |
521 # pass | |
522 ex[0]+1 | |
523 # ex[id]*mat | |
524 @print_timing | |
525 def f_ds_mb2(ds,mb_size): | |
526 for exs in ds.minibatches(minibatch_size = mb_size): | |
527 # pass | |
528 exs[0]+1 | |
529 # ex[id]*mat | |
530 | |
531 f_array1(a2) | |
532 f_array2(a2) | |
533 | |
534 f_ds(ds) | |
535 | |
536 f_ds_mb1(ds,10) | |
537 f_ds_mb1(ds,100) | |
538 f_ds_mb1(ds,1000) | |
539 f_ds_mb1(ds,10000) | |
540 f_ds_mb2(ds,10) | |
541 f_ds_mb2(ds,100) | |
542 f_ds_mb2(ds,1000) | |
543 f_ds_mb2(ds,10000) | |
544 | |
545 del a2, ds | |
546 | |
488 if __name__=='__main__': | 547 if __name__=='__main__': |
489 test1() | 548 test1() |
490 test_LookupList() | 549 test_LookupList() |
491 test_ArrayDataSet() | 550 test_ArrayDataSet() |
492 test_CachedDataSet() | 551 test_CachedDataSet() |
493 test_ApplyFunctionDataSet() | 552 test_ApplyFunctionDataSet() |
553 #test_speed() | |
494 #test pmat.py | 554 #test pmat.py |
555 |