Mercurial > pylearn
diff _test_dataset.py @ 315:b48cf8dce2bf
test to compare overriden __getitem__ implemented, tested on ArrayDataSet.__getitem__
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Wed, 11 Jun 2008 16:26:41 -0400 |
parents | 96cca78de3ed |
children | 9c08e3af975e 4efb503fd0da |
line wrap: on
line diff
--- a/_test_dataset.py Wed Jun 11 13:57:34 2008 -0400 +++ b/_test_dataset.py Wed Jun 11 16:26:41 2008 -0400 @@ -267,7 +267,7 @@ assert (orig[index[i]]['x']==array[index[i]][:3]).all() assert (orig[index[i]]['x']==x).all() assert orig[index[i]]['y']==array[index[i]][3] - assert (orig[index[i]]['y']==y).all() + assert (orig[index[i]]['y']==y).all() # why does it crash sometimes? assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() assert (orig[index[i]]['z']==z).all() i+=1 @@ -375,14 +375,90 @@ assert len(ds('y').fields()) == 1 del field + +def test_overrides(ds) : + """ Test for examples that an override __getitem__ acts as the one in DataSet """ + def ndarray_list_equal(nda,l) : + """ + Compares if a ndarray is the same as the list. Do it by converting the list into + an numpy.ndarray, if possible + """ + try : + l = numpy.asmatrix(l) + except : + return False + return smart_equal(nda,l) + + def smart_equal(a1,a2) : + """ + Handles numpy.ndarray, LookupList, and basic containers + """ + if not isinstance(a1,type(a2)) and not isinstance(a2,type(a1)): + #special case: matrix vs list of arrays + if isinstance(a1,numpy.ndarray) : + return ndarray_list_equal(a1,a2) + elif isinstance(a2,numpy.ndarray) : + return ndarray_list_equal(a2,a1) + return False + # compares 2 numpy.ndarray + if isinstance(a1,numpy.ndarray): + if len(a1.shape) != len(a2.shape): + return False + for k in range(len(a1.shape)) : + if a1.shape[k] != a2.shape[k]: + return False + return (a1==a2).all() + # compares 2 lookuplists + if isinstance(a1,LookupList) : + if len(a1._names) != len(a2._names) : + return False + for k in a1._names : + if k not in a2._names : + return False + if not smart_equal(a1[k],a2[k]) : + return False + return True + # compares 2 basic containers + if hasattr(a1,'__len__'): + if len(a1) != len(a2) : + return False + for k in range(len(a1)) : + if not smart_equal(a1[k],a2[k]): + return False + return True + # try basic equals + return a1 is a2 + + def mask(ds) : + class TestOverride(type(ds)): + def __init__(self,ds) : + self.ds = ds + def __getitem__(self,key) : + res1 = self.ds[key] + res2 = DataSet.__getitem__(ds,key) + assert smart_equal(res1,res2) + return res1 + return TestOverride(ds) + # test getitem + ds2 = mask(ds) + for k in range(10): + res = ds2[k] + res = ds2[1:len(ds):3] + + + + + + def test_all(array,ds): assert len(ds)==10 - test_iterate_over_examples(array, ds) + test_overrides(ds) test_getitem(array, ds) test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z')) test_fields_fct(ds) + class T_DataSet(unittest.TestCase): def test_ArrayDataSet(self): #don't test stream