comparison test_dataset.py @ 231:38beb81f4e8b

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 27 May 2008 13:46:03 -0400
parents df3fae88ab46 4d1bd2513e06
children 6aff510792dd
comparison
equal deleted inserted replaced
227:17c5d080964b 231:38beb81f4e8b
1 #!/bin/env python 1 #!/bin/env python
2 from dataset import * 2 from dataset import *
3 from math import * 3 from math import *
4 import numpy 4 import numpy
5 from misc import *
5 6
6 def have_raised(to_eval, **var): 7 def have_raised(to_eval, **var):
7 have_thrown = False 8 have_thrown = False
8 try: 9 try:
9 eval(to_eval) 10 eval(to_eval)
461 a = numpy.random.rand(10,4) 462 a = numpy.random.rand(10,4)
462 a2 = a+1 463 a2 = a+1
463 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested 464 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
464 465
465 ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False) 466 ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False)
466 ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True) 467 ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1),
468 ['x','y','z'],
469 minibatch_mode=True)
467 470
468 test_all(a2,ds2) 471 test_all(a2,ds2)
469 test_all(a2,ds3) 472 test_all(a2,ds3)
470 473
471 del a,ds1,ds2,ds3 474 del a,ds1,ds2,ds3
483 print "test_VStackedDataSet" 486 print "test_VStackedDataSet"
484 raise NotImplementedError() 487 raise NotImplementedError()
def test_ArrayFieldsDataSet():
    """Placeholder for the ArrayFieldsDataSet tests.

    Prints the test name and then raises ``NotImplementedError``
    unconditionally; no checks have been written yet.
    """
    print("test_ArrayFieldsDataSet")
    raise NotImplementedError
491
492
def test_speed():
    """Compare the cost of walking a raw numpy array and an ArrayDataSet.

    A 100000 x 400 random matrix is created and wrapped in an
    ``ArrayDataSet`` that exposes every column through a single field
    named ``'all'``. The same trivial "add one" workload is then pushed
    through several access paths:

    * the whole array at once,
    * the array row by row,
    * the dataset example by example,
    * the dataset minibatch by minibatch, touching each example,
    * the dataset minibatch by minibatch, touching the minibatch as a whole.

    Each path is a local helper decorated with ``print_timing``
    (presumably provided by ``from misc import *`` and presumably
    reporting the elapsed time of the call -- confirm against misc.py).
    The results of the additions are discarded on purpose: only the
    access cost is of interest.

    This is a benchmark, not a correctness test, and it is slow; the
    call in the ``__main__`` section of this file is commented out.
    """
    print("test_speed")

    a2 = numpy.random.rand(100000, 400)
    # Only one field layout is benchmarked: all columns under 'all'.
    ds = ArrayDataSet(a2, {'all': slice(0, a2.shape[1], 1)})
    # Open question kept from the original author:
    # assert ds==a? should this work?

    @print_timing
    def f_array1(a):
        # One vectorised operation over the whole array.
        a + 1

    @print_timing
    def f_array2(a):
        # Same work, one row at a time.
        for row in range(a.shape[0]):
            a[row] + 1

    @print_timing
    def f_ds(ds):
        # Iterate the dataset one example at a time.
        for ex in ds:
            ex[0] + 1

    @print_timing
    def f_ds_mb1(ds, mb_size):
        # Fetch minibatches, then still touch every example individually.
        for exs in ds.minibatches(minibatch_size=mb_size):
            for ex in exs:
                ex[0] + 1

    @print_timing
    def f_ds_mb2(ds, mb_size):
        # Fetch minibatches and operate on element 0 of each minibatch.
        for exs in ds.minibatches(minibatch_size=mb_size):
            exs[0] + 1

    f_array1(a2)
    f_array2(a2)

    f_ds(ds)

    # Same order as before: every size through f_ds_mb1 first, then
    # every size through f_ds_mb2.
    mb_sizes = (10, 100, 1000, 10000)
    for mb_size in mb_sizes:
        f_ds_mb1(ds, mb_size)
    for mb_size in mb_sizes:
        f_ds_mb2(ds, mb_size)

    # Drop the local references explicitly, as the other tests in this
    # file do.
    del a2, ds
546
488 if __name__=='__main__': 547 if __name__=='__main__':
489 test1() 548 test1()
490 test_LookupList() 549 test_LookupList()
491 test_ArrayDataSet() 550 test_ArrayDataSet()
492 test_CachedDataSet() 551 test_CachedDataSet()
493 test_ApplyFunctionDataSet() 552 test_ApplyFunctionDataSet()
553 #test_speed()
494 #test pmat.py 554 #test pmat.py
555