Mercurial > pylearn
comparison test_dataset.py @ 259:621faba17c60
Created 'dummytests': tests that check the consistency of new, unusual datasets for which we cannot compare against actual values (for instance, values stored in a matrix). Useful as a first debugging step when creating a dataset.
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 16:41:55 -0400 |
parents | bf0a1ebc6e52 |
children | 792f81d65f82 |
comparison
equal
deleted
inserted
replaced
258:19b14afe04b7 | 259:621faba17c60 |
---|---|
556 f_ds_mb2(ds,100) | 556 f_ds_mb2(ds,100) |
557 f_ds_mb2(ds,1000) | 557 f_ds_mb2(ds,1000) |
558 f_ds_mb2(ds,10000) | 558 f_ds_mb2(ds,10000) |
559 | 559 |
560 | 560 |
561 | |
562 | |
563 | |
564 | |
565 #**************************************************************** | |
566 # dummy tests, less powerful than the previous tests, but can work with any new weird dataset. | |
567 # Basically, emphasis is put on consistency, but it never checks the actual values. | |
568 # To be used as a checklist, or a first test, when creating a new dataset | |
569 | |
570 def dummytest_all(ds) : | |
571 """ Launches all the dummytests with a given dataset. """ | |
572 | |
573 | |
574 def test1_basicstats(self,ds) : | |
575 """print basics stats on a dataset, like length""" | |
576 | |
577 print 'len(ds) = ',len(ds) | |
578 print 'num fields = ', len(ds.fieldNames()) | |
579 print 'types of field: ', | |
580 for k in ds.fieldNames() : | |
581 print type(ds[0](k)[0]), | |
582 print '' | |
583 | |
584 def dummytest1_basicstats(self,ds) : | |
585 """print basics stats on a dataset, like length""" | |
586 | |
587 print 'len(ds) = ',len(ds) | |
588 print 'num fields = ', len(ds.fieldNames()) | |
589 print 'types of field: ', | |
590 for k in ds.fieldNames() : | |
591 print type(ds[0](k)[0]), | |
592 print '' | |
593 | |
594 def dummytest2_slicing(self,ds) : | |
595 """test if slicing works properly""" | |
596 print 'testing slicing...', | |
597 sys.stdout.flush() | |
598 | |
599 middle = len(ds) / 2 | |
600 tenpercent = int(len(ds) * .1) | |
601 set1 = ds[:middle+tenpercent] | |
602 set2 = ds[middle-tenpercent:] | |
603 for k in range(tenpercent + tenpercent -1): | |
604 for k2 in ds.fieldNames() : | |
605 if type(set1[middle-tenpercent+k](k2)[0]) == N.ndarray : | |
606 for k3 in range(len(set1[middle-tenpercent+k](k2)[0])) : | |
607 assert set1[middle-tenpercent+k](k2)[0][k3] == set2[k](k2)[0][k3] | |
608 else : | |
609 assert set1[middle-tenpercent+k](k2)[0] == set2[k](k2)[0] | |
610 assert tenpercent > 1 | |
611 set3 = ds[middle-tenpercent:middle+tenpercent:2] | |
612 for k2 in ds.fieldNames() : | |
613 if type(set2[2](k2)[0]) == N.ndarray : | |
614 for k3 in range(len(set2[2](k2)[0])) : | |
615 assert set2[2](k2)[0][k3] == set3[1](k2)[0][k3] | |
616 else : | |
617 assert set2[2](k2)[0] == set3[1](k2)[0] | |
618 | |
619 print 'done' | |
620 | |
621 | |
622 def dummytest3_fields_iterator_consistency(self,ds) : | |
623 """ check if the number of iterator corresponds to the number of fields""" | |
624 print 'testing fields/iterator consistency...', | |
625 sys.stdout.flush() | |
626 | |
627 # basic test | |
628 maxsize = min(len(ds)-1,100) | |
629 for iter in ds[:maxsize] : | |
630 assert len(iter) == len(ds.fieldNames()) | |
631 if len(ds.fieldNames()) == 1 : | |
632 print 'done' | |
633 return | |
634 | |
635 # with minibatches iterator | |
636 ds2 = ds[:maxsize].minibatches([ds.fieldNames()[0],ds.fieldNames()[1]],minibatch_size=2) | |
637 for iter in ds2 : | |
638 assert len(iter) == 2 | |
639 | |
640 print 'done' | |
641 | |
642 | |
643 | |
644 | |
645 | |
646 | |
647 | |
648 | |
649 | |
561 if __name__=='__main__': | 650 if __name__=='__main__': |
562 test1() | 651 test1() |
563 test_LookupList() | 652 test_LookupList() |
564 test_ArrayDataSet() | 653 test_ArrayDataSet() |
565 test_CachedDataSet() | 654 test_CachedDataSet() |