comparison test_dataset.py @ 207:c5a7105fa40b

trying to merge
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Fri, 16 May 2008 16:38:15 -0400
parents b9950ae5e54b
children 3c96d23b36ac d7250ee86f72
comparison
equal deleted inserted replaced
206:f2ddc795ec49 207:c5a7105fa40b
392 assert len(ds('x','y').fields()) == 2 392 assert len(ds('x','y').fields()) == 2
393 assert len(ds('x','z').fields()) == 2 393 assert len(ds('x','z').fields()) == 2
394 assert len(ds('y').fields()) == 1 394 assert len(ds('y').fields()) == 1
395 395
396 del field 396 del field
397 def test_all(array,ds):
398 assert len(ds)==10
399
400 test_iterate_over_examples(array, ds)
401 test_getitem(array, ds)
402 test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z'))
403 test_fields_fct(ds)
397 404
398 def test_ArrayDataSet(): 405 def test_ArrayDataSet():
399 #don't test stream 406 #don't test stream
400 #tested only with float value 407 #tested only with float value
401 #don't always test with y 408 #don't always test with y
404 #don't test proterties 411 #don't test proterties
405 print "test_ArrayDataSet" 412 print "test_ArrayDataSet"
406 a2 = numpy.random.rand(10,4) 413 a2 = numpy.random.rand(10,4)
407 ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested 414 ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
408 ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested 415 ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
409 assert len(ds)==10
410 #assert ds==a? should this work? 416 #assert ds==a? should this work?
411 417
412 test_iterate_over_examples(a2, ds) 418 test_all(a2,ds)
413 test_getitem(a2, ds)
414 test_ds_iterator(a2,ds('x','y'),ds('y','z'),ds('x','y','z'))
415 test_fields_fct(ds)
416 419
417 del a2, ds 420 del a2, ds
418 421
419 def test_LookupList(): 422 def test_LookupList():
420 #test only the example in the doc??? 423 #test only the example in the doc???
440 print "test_CacheDataSet" 443 print "test_CacheDataSet"
441 a = numpy.random.rand(10,4) 444 a = numpy.random.rand(10,4)
442 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested 445 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
443 ds2 = CachedDataSet(ds1) 446 ds2 = CachedDataSet(ds1)
444 ds3 = CachedDataSet(ds1,cache_all_upon_construction=True) 447 ds3 = CachedDataSet(ds1,cache_all_upon_construction=True)
445 assert len(ds2)==10 448
446 assert len(ds3)==10 449 test_all(a,ds2)
447 450 test_all(a,ds3)
448 test_iterate_over_examples(a, ds2)
449 test_getitem(a, ds2)
450 test_ds_iterator(a,ds2('x','y'),ds2('y','z'),ds2('x','y','z'))
451 test_fields_fct(ds2)
452
453 test_iterate_over_examples(a, ds3)
454 test_getitem(a, ds3)
455 test_ds_iterator(a,ds3('x','y'),ds3('y','z'),ds3('x','y','z'))
456 test_fields_fct(ds3)
457 451
458 del a,ds1,ds2,ds3 452 del a,ds1,ds2,ds3
459 453
460 454
461 def test_DataSetFields(): 455 def test_DataSetFields():
462 print "test_DataSetFields" 456 print "test_DataSetFields"
463 raise NotImplementedError() 457 raise NotImplementedError()
464 458
465 def test_ApplyFunctionDataSet(): 459 def test_ApplyFunctionDataSet():
466 print "test_ApplyFunctionDataSet" 460 print "test_ApplyFunctionDataSet"
467 raise NotImplementedError() 461 a = numpy.random.rand(10,4)
462 a2 = a+1
463 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
464
465 ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False)
466 ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True)
467
468 test_all(a2,ds2)
469 test_all(a2,ds3)
470
471 del a,ds1,ds2,ds3
472
468 def test_FieldsSubsetDataSet(): 473 def test_FieldsSubsetDataSet():
469 print "test_FieldsSubsetDataSet" 474 print "test_FieldsSubsetDataSet"
470 raise NotImplementedError() 475 raise NotImplementedError()
471 def test_MinibatchDataSet(): 476 def test_MinibatchDataSet():
472 print "test_MinibatchDataSet" 477 print "test_MinibatchDataSet"
483 if __name__=='__main__': 488 if __name__=='__main__':
484 test1() 489 test1()
485 test_LookupList() 490 test_LookupList()
486 test_ArrayDataSet() 491 test_ArrayDataSet()
487 test_CachedDataSet() 492 test_CachedDataSet()
493 test_ApplyFunctionDataSet()
488 #test pmat.py 494 #test pmat.py