Mercurial > pylearn
diff _test_dataset.py @ 221:58e17421c69c
tester on iterator consistency now triggers a bug in dataset, linked to the combination of minibatch and slicing
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Fri, 23 May 2008 14:07:53 -0400 |
parents | 1f527fe65e22 |
children | 174374d59405 |
line wrap: on
line diff
--- a/_test_dataset.py Fri May 23 13:44:25 2008 -0400 +++ b/_test_dataset.py Fri May 23 14:07:53 2008 -0400 @@ -112,6 +112,7 @@ if runall : self.test1_basicstats(ds) self.test2_slicing(ds) + self.test3_fields_iterator_consistency(ds) def test1_basicstats(self,ds) : """print basics stats on a dataset, like length""" @@ -139,10 +140,42 @@ assert set1[middle-tenpercent+k](k2)[0][k3] == set2[k](k2)[0][k3] else : assert set1[middle-tenpercent+k](k2)[0] == set2[k](k2)[0] + assert tenpercent > 1 + set3 = ds[middle-tenpercent:middle+tenpercent:2] + for k2 in ds.fieldNames() : + if type(set2[2](k2)[0]) == N.ndarray : + for k3 in range(len(set2[2](k2)[0])) : + assert set2[2](k2)[0][k3] == set3[1](k2)[0][k3] + else : + assert set2[2](k2)[0] == set3[1](k2)[0] print 'done' + def test3_fields_iterator_consistency(self,ds) : + """ check if the number of iterator corresponds to the number of fields""" + print 'testing fields/iterator consistency...', + sys.stdout.flush() + + # basic test + maxsize = min(len(ds)-1,100) + for iter in ds[:maxsize] : + assert len(iter) == len(ds.fieldNames()) + if len(ds.fieldNames()) == 1 : + print 'done' + return + + # with minibatches iterator + ds2 = ds.minibatches[:maxsize]([ds.fieldNames()[0],ds.fieldNames()[1]],minibatch_size=2) + for iter in ds2 : + assert len(iter) == 2 + + print 'done' + + + + + ################################################################### # main if __name__ == '__main__':