Mercurial > pylearn
diff test_dataset.py @ 268:3f1cd8897fda
reverting dataset
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 04 Jun 2008 18:48:50 -0400 |
parents | 6e69fb91f3c0 |
children | fdce496c3b56 |
line wrap: on
line diff
--- a/test_dataset.py Wed Jun 04 17:49:28 2008 -0400 +++ b/test_dataset.py Wed Jun 04 18:48:50 2008 -0400 @@ -421,7 +421,7 @@ test_all(a2,ds) - del a2, ds #removes from list of active objects in debugger + del a2, ds def test_LookupList(): #test only the example in the doc??? @@ -544,8 +544,6 @@ f_array_iter(array) f_ds_index(ds) - f_ds_index(ds) - f_ds_iter(ds) f_ds_iter(ds) f_ds_mb1(ds,10) @@ -558,92 +556,8 @@ f_ds_mb2(ds,10000) - - - - -#**************************************************************** -# dummy tests, less powerful than the previous tests, but can work with any new weird dataset. -# Basically, emphasis is put on consistency, but it never checks the actual values. -# To be used as a checklist, or a first test, when creating a new dataset - -def dummytest_all(ds) : - """ Launches all the dummytests with a given dataset. """ - - dummytest1_basicstats(ds) - dummytest2_slicing(ds) - dummytest3_fields_iterator_consistency(ds) - - -def dummytest1_basicstats(ds) : - """print basics stats on a dataset, like length""" - - print 'len(ds) = ',len(ds) - print 'num fields = ', len(ds.fieldNames()) - print 'types of field: ', - for k in ds.fieldNames() : - print type(ds[0](k)[0]), - print '' - -def dummytest2_slicing(ds) : - """test if slicing seems to works properly""" - print 'testing slicing...', - sys.stdout.flush() - - middle = len(ds) / 2 - tenpercent = int(len(ds) * .1) - set1 = ds[:middle+tenpercent] - set2 = ds[middle-tenpercent:] - for k in range(tenpercent + tenpercent -1): - for k2 in ds.fieldNames() : - if type(set1[middle-tenpercent+k](k2)[0]) == N.ndarray : - for k3 in range(len(set1[middle-tenpercent+k](k2)[0])) : - assert set1[middle-tenpercent+k](k2)[0][k3] == set2[k](k2)[0][k3] - else : - assert set1[middle-tenpercent+k](k2)[0] == set2[k](k2)[0] - assert tenpercent > 1 - set3 = ds[middle-tenpercent:middle+tenpercent:2] - for k2 in ds.fieldNames() : - if type(set2[2](k2)[0]) == N.ndarray : - for k3 in range(len(set2[2](k2)[0])) : - assert set2[2](k2)[0][k3] == set3[1](k2)[0][k3] - else : - assert set2[2](k2)[0] == set3[1](k2)[0] - - print 'done' - - -def dummytest3_fields_iterator_consistency(ds) : - """test if the number of iterator corresponds to the number of fields, also do it for minibatches""" - print 'testing fields/iterator consistency...', - sys.stdout.flush() - - # basic test - maxsize = min(len(ds)-1,100) - for iter in ds[:maxsize] : - assert len(iter) == len(ds.fieldNames()) - if len(ds.fieldNames()) == 1 : - print 'done' - return - - # with minibatches iterator - ds2 = ds[:maxsize].minibatches([ds.fieldNames()[0],ds.fieldNames()[1]],minibatch_size=2) - for iter in ds2 : - assert len(iter) == 2 - - print 'done' - - - - - - - - - if __name__=='__main__': - if 0: - test1() + test1() test_LookupList() test_ArrayDataSet() test_CachedDataSet()