Mercurial > pylearn
comparison test_dataset.py @ 207:c5a7105fa40b
trying to merge
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Fri, 16 May 2008 16:38:15 -0400 |
parents | b9950ae5e54b |
children | 3c96d23b36ac d7250ee86f72 |
comparison
equal
deleted
inserted
replaced
206:f2ddc795ec49 | 207:c5a7105fa40b |
---|---|
392 assert len(ds('x','y').fields()) == 2 | 392 assert len(ds('x','y').fields()) == 2 |
393 assert len(ds('x','z').fields()) == 2 | 393 assert len(ds('x','z').fields()) == 2 |
394 assert len(ds('y').fields()) == 1 | 394 assert len(ds('y').fields()) == 1 |
395 | 395 |
396 del field | 396 del field |
397 def test_all(array,ds): | |
398 assert len(ds)==10 | |
399 | |
400 test_iterate_over_examples(array, ds) | |
401 test_getitem(array, ds) | |
402 test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z')) | |
403 test_fields_fct(ds) | |
397 | 404 |
398 def test_ArrayDataSet(): | 405 def test_ArrayDataSet(): |
399 #don't test stream | 406 #don't test stream |
400 #tested only with float value | 407 #tested only with float value |
401 #don't always test with y | 408 #don't always test with y |
404 #don't test properties | 411 #don't test properties |
405 print "test_ArrayDataSet" | 412 print "test_ArrayDataSet" |
406 a2 = numpy.random.rand(10,4) | 413 a2 = numpy.random.rand(10,4) |
407 ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested | 414 ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested |
408 ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested | 415 ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested |
409 assert len(ds)==10 | |
410 #assert ds==a? should this work? | 416 #assert ds==a? should this work? |
411 | 417 |
412 test_iterate_over_examples(a2, ds) | 418 test_all(a2,ds) |
413 test_getitem(a2, ds) | |
414 test_ds_iterator(a2,ds('x','y'),ds('y','z'),ds('x','y','z')) | |
415 test_fields_fct(ds) | |
416 | 419 |
417 del a2, ds | 420 del a2, ds |
418 | 421 |
419 def test_LookupList(): | 422 def test_LookupList(): |
420 #test only the example in the doc??? | 423 #test only the example in the doc??? |
440 print "test_CacheDataSet" | 443 print "test_CacheDataSet" |
441 a = numpy.random.rand(10,4) | 444 a = numpy.random.rand(10,4) |
442 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested | 445 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested |
443 ds2 = CachedDataSet(ds1) | 446 ds2 = CachedDataSet(ds1) |
444 ds3 = CachedDataSet(ds1,cache_all_upon_construction=True) | 447 ds3 = CachedDataSet(ds1,cache_all_upon_construction=True) |
445 assert len(ds2)==10 | 448 |
446 assert len(ds3)==10 | 449 test_all(a,ds2) |
447 | 450 test_all(a,ds3) |
448 test_iterate_over_examples(a, ds2) | |
449 test_getitem(a, ds2) | |
450 test_ds_iterator(a,ds2('x','y'),ds2('y','z'),ds2('x','y','z')) | |
451 test_fields_fct(ds2) | |
452 | |
453 test_iterate_over_examples(a, ds3) | |
454 test_getitem(a, ds3) | |
455 test_ds_iterator(a,ds3('x','y'),ds3('y','z'),ds3('x','y','z')) | |
456 test_fields_fct(ds3) | |
457 | 451 |
458 del a,ds1,ds2,ds3 | 452 del a,ds1,ds2,ds3 |
459 | 453 |
460 | 454 |
461 def test_DataSetFields(): | 455 def test_DataSetFields(): |
462 print "test_DataSetFields" | 456 print "test_DataSetFields" |
463 raise NotImplementedError() | 457 raise NotImplementedError() |
464 | 458 |
465 def test_ApplyFunctionDataSet(): | 459 def test_ApplyFunctionDataSet(): |
466 print "test_ApplyFunctionDataSet" | 460 print "test_ApplyFunctionDataSet" |
467 raise NotImplementedError() | 461 a = numpy.random.rand(10,4) |
462 a2 = a+1 | |
463 ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested | |
464 | |
465 ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False) | |
466 ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True) | |
467 | |
468 test_all(a2,ds2) | |
469 test_all(a2,ds3) | |
470 | |
471 del a,ds1,ds2,ds3 | |
472 | |
468 def test_FieldsSubsetDataSet(): | 473 def test_FieldsSubsetDataSet(): |
469 print "test_FieldsSubsetDataSet" | 474 print "test_FieldsSubsetDataSet" |
470 raise NotImplementedError() | 475 raise NotImplementedError() |
471 def test_MinibatchDataSet(): | 476 def test_MinibatchDataSet(): |
472 print "test_MinibatchDataSet" | 477 print "test_MinibatchDataSet" |
483 if __name__=='__main__': | 488 if __name__=='__main__': |
484 test1() | 489 test1() |
485 test_LookupList() | 490 test_LookupList() |
486 test_ArrayDataSet() | 491 test_ArrayDataSet() |
487 test_CachedDataSet() | 492 test_CachedDataSet() |
493 test_ApplyFunctionDataSet() | |
488 #test pmat.py | 494 #test pmat.py |