changeset 772:33f46eee4a96

added more test for the FillMissing op.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Wed, 10 Jun 2009 13:43:27 -0400
parents 72730f38d1fb
children a25d2229a091
files pylearn/sandbox/test_scan_inputs_groups.py
diffstat 1 files changed, 82 insertions(+), 23 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/sandbox/test_scan_inputs_groups.py	Wed Jun 10 13:42:56 2009 -0400
+++ b/pylearn/sandbox/test_scan_inputs_groups.py	Wed Jun 10 13:43:27 2009 -0400
@@ -7,38 +7,97 @@
 
 from theano import function, Mode
 import theano.tensor as T
-from pylearn.sandbox.scan_inputs_groups import FillMissing
-
-
-if __name__ == '__main__':
-    t = TestConvOp("test_convolution")
-    t.test_convolution()
-    t.test_multilayer_conv()
-    from theano.tests import main
-    main("test_sp")
+from pylearn.sandbox.scan_inputs_groups2 import FillMissing
 
 class TestFillMissing(unittest.TestCase):
     def setUp(self):
         utt.seed_rng()
 
-    def test_base(self):
+    def test_vector(self):
+        n=100000
         v=T.dvector()
+        def t(prob,val,fill):
+            op=FillMissing(fill)(v)
+            f=function([v],op)
+            nb_missing=0
+            for i in range(n):
+                if prob[i]<0.1:
+                    nb_missing+=1
+                    val[i]=N.nan
+            out=f(val)
+            for i in range(n):
+                if N.isnan(val[i]):
+                    if isinstance(fill,N.ndarray):
+                        assert out[0][i]==fill[i]
+                    else:
+                        assert out[0][i]==fill
+                else:
+                    assert out[1][i]==1                
+
+        prob=N.random.random(n)
+        val=N.random.random(n)
+
+        fill=0
+        t(prob,val,fill)#test with fill a constant
+
+        fill=N.random.random(n)
+        t(prob,val,fill)#test with fill a vector
+
+#TODO: test fill_with_array!
+    def test_matrix(self):
+        shp=(100,100)
+        v=T.dmatrix()
         op=FillMissing()(v)
         fct=function([v],op)
         
-        prob=N.random.random(1000)
-        val=N.random.random(len(prob))
+        prob=N.random.random(N.prod(shp)).reshape(shp)
+        val=N.random.random(shp)
         nb_missing=0
-        for i in range(len(val)):
-            if prob[i]<0.1:
-                nb_missing+=1
-                val[i]=N.nan
+        for i in range(shp[0]):
+            for j in range(shp[1]):
+                if prob[i,j]<0.1:
+                    nb_missing+=1
+                    val[i,j]=N.nan
 
         out=fct(val)
-        for i in range(len(prob)):
-            if N.isnan(val[i]):
-                assert out[0][i]==0
-                assert out[1][i]==0
-            else:
-                assert out[1][i]==1
-                
+        for i in range(shp[0]):
+            for j in range(shp[1]):
+                if N.isnan(val[i,j]):
+                    assert out[0][i,j]==0
+                    assert out[1][i,j]==0
+                else:
+                    assert out[1][i,j]==1
+ 
+#TODO: test fill_with_array!
+    def test_matrix3d(self):
+        shp=(100,100,100)
+        v= T.TensorType('float64', (False, False, False))()
+        op=FillMissing()(v)
+        fct=function([v],op)
+        
+        prob=N.random.random(N.prod(shp)).reshape(shp)
+        val=N.random.random(prob.shape)
+        nb_missing=0
+        for i in range(shp[0]):
+            for j in range(shp[1]):
+                for k in range(shp[2]):
+                    if prob[i,j,k]<0.1:
+                        nb_missing+=1
+                        val[i,j,k]=N.nan
+
+        out=fct(val)
+        for i in range(shp[0]):
+            for j in range(shp[1]):
+                for k in range(shp[2]):
+                    if N.isnan(val[i,j,k]):
+                        assert out[0][i,j,k]==0
+                        assert out[1][i,j,k]==0
+                    else:
+                        assert out[1][i,j,k]==1
+
+if __name__ == '__main__':
+    t = TestFillMissing("test_vector")
+    t.test_vector()
+#    from theano.tests import main
+#    main("test_sp")
+