diff _test_dataset.py @ 315:b48cf8dce2bf

test to compare overriden __getitem__ implemented, tested on ArrayDataSet.__getitem__
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Wed, 11 Jun 2008 16:26:41 -0400
parents 96cca78de3ed
children 9c08e3af975e 4efb503fd0da
line wrap: on
line diff
--- a/_test_dataset.py	Wed Jun 11 13:57:34 2008 -0400
+++ b/_test_dataset.py	Wed Jun 11 16:26:41 2008 -0400
@@ -267,7 +267,7 @@
             assert (orig[index[i]]['x']==array[index[i]][:3]).all()
             assert (orig[index[i]]['x']==x).all()
             assert orig[index[i]]['y']==array[index[i]][3]
-            assert (orig[index[i]]['y']==y).all()
+            assert (orig[index[i]]['y']==y).all() # why does it crash sometimes?
             assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all()
             assert (orig[index[i]]['z']==z).all()
             i+=1
@@ -375,14 +375,90 @@
     assert len(ds('y').fields()) == 1
 
     del field
+
+def test_overrides(ds) :
+    """ Test for examples that an override __getitem__ acts as the one in DataSet """
+    def ndarray_list_equal(nda,l) :
+        """ 
+        Compares if a ndarray is the same as the list. Do it by converting the list into
+        an numpy.ndarray, if possible
+        """
+        try :
+            l = numpy.asmatrix(l)
+        except :
+            return False
+        return smart_equal(nda,l)
+        
+    def smart_equal(a1,a2) :
+        """
+        Handles numpy.ndarray, LookupList, and basic containers
+        """
+        if not isinstance(a1,type(a2)) and not isinstance(a2,type(a1)):
+            #special case: matrix vs list of arrays
+            if isinstance(a1,numpy.ndarray) :
+                return ndarray_list_equal(a1,a2)
+            elif isinstance(a2,numpy.ndarray) :
+                return ndarray_list_equal(a2,a1)
+            return False
+        # compares 2 numpy.ndarray
+        if isinstance(a1,numpy.ndarray):
+            if len(a1.shape) != len(a2.shape):
+                return False
+            for k in range(len(a1.shape)) :
+                if a1.shape[k] != a2.shape[k]:
+                    return False
+            return (a1==a2).all()
+        # compares 2 lookuplists
+        if isinstance(a1,LookupList) :
+            if len(a1._names) != len(a2._names) :
+                return False
+            for k in a1._names :
+                if k not in a2._names :
+                    return False
+                if not smart_equal(a1[k],a2[k]) :
+                    return False
+            return True
+        # compares 2 basic containers
+        if hasattr(a1,'__len__'):
+            if len(a1) != len(a2) :
+                return False
+            for k in range(len(a1)) :
+                if not smart_equal(a1[k],a2[k]):
+                    return False
+            return True
+        # try basic equals
+        return a1 is a2
+
+    def mask(ds) :
+        class TestOverride(type(ds)):
+            def __init__(self,ds) :
+                self.ds = ds
+            def __getitem__(self,key) :
+                res1 = self.ds[key]
+                res2 = DataSet.__getitem__(ds,key)
+                assert smart_equal(res1,res2)
+                return res1
+        return TestOverride(ds)
+    # test getitem
+    ds2 = mask(ds)
+    for k in range(10):
+        res = ds2[k]
+    res = ds2[1:len(ds):3]
+    
+        
+
+    
+
+
 def test_all(array,ds):
     assert len(ds)==10
-
     test_iterate_over_examples(array, ds)
+    test_overrides(ds)
     test_getitem(array, ds)
     test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z'))
     test_fields_fct(ds)
 
+
 class T_DataSet(unittest.TestCase):
     def test_ArrayDataSet(self):
         #don't test stream