changeset 775:164a76c8346f

test more case in FillMissing and print the time of the op.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Thu, 11 Jun 2009 13:35:31 -0400
parents 41761210d16e
children 72ce8288a283
files pylearn/sandbox/test_scan_inputs_groups.py
diffstat 1 files changed, 43 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/sandbox/test_scan_inputs_groups.py	Thu Jun 11 11:37:24 2009 -0400
+++ b/pylearn/sandbox/test_scan_inputs_groups.py	Thu Jun 11 13:35:31 2009 -0400
@@ -1,6 +1,6 @@
 import sys, time, unittest
 
-import numpy
+import numpy,time
 import numpy as N
 
 from theano.tests import unittest_tools as utt
@@ -14,6 +14,7 @@
         utt.seed_rng()
 
     def test_vector(self):
+        print "test_vector"
         n=100000
         v=T.dvector()
         def t(prob,val,fill):
@@ -24,14 +25,17 @@
                 if prob[i]<0.1:
                     nb_missing+=1
                     val[i]=N.nan
+            t=time.time()
             out=f(val)
+            print "time %.3fs"%(time.time()-t)
             for i in range(n):
                 if N.isnan(val[i]):
                     if isinstance(fill,N.ndarray):
-                        assert out[0][i]==fill[i]
+                        assert abs(out[0][i]-fill[i])<1e-6
                     else:
                         assert out[0][i]==fill
                 else:
+                    assert out[0][i]==val[i]
                     assert out[1][i]==1                
 
         prob=N.random.random(n)
@@ -45,32 +49,47 @@
 
 #TODO: test fill_with_array!
     def test_matrix(self):
-        shp=(100,100)
+        print "test_matrix"
+        shp=(100,10)
         v=T.dmatrix()
-        op=FillMissing()(v)
-        fct=function([v],op)
+        def t(prob,val,fill):
+            op=FillMissing(fill)(v)
+            f=function([v],op)
         
+            nb_missing=0
+            for i in range(shp[0]):
+                for j in range(shp[1]):
+                    if prob[i,j]<0.1:
+                        nb_missing+=1
+                        val[i,j]=N.nan
+            t=time.time()
+            out=f(val)
+            print "time %.3fs"%(time.time()-t)
+            for i in range(shp[0]):
+                for j in range(shp[1]):
+                    if N.isnan(val[i,j]):
+                        if isinstance(fill,N.ndarray):
+                            assert abs(out[0][i,j]-fill[j])<1e-6
+                        else:
+                            assert out[0][i,j]==fill
+                    else:
+                        assert out[0][i,j]==val[i,j]
+                        assert out[1][i,j]==1
+
+
         prob=N.random.random(N.prod(shp)).reshape(shp)
         val=N.random.random(shp)
-        nb_missing=0
-        for i in range(shp[0]):
-            for j in range(shp[1]):
-                if prob[i,j]<0.1:
-                    nb_missing+=1
-                    val[i,j]=N.nan
 
-        out=fct(val)
-        for i in range(shp[0]):
-            for j in range(shp[1]):
-                if N.isnan(val[i,j]):
-                    assert out[0][i,j]==0
-                    assert out[1][i,j]==0
-                else:
-                    assert out[1][i,j]==1
- 
+        fill=0
+        t(prob,val,fill)#test with fill a constant
+        
+        fill=N.random.random(shp[1])
+        t(prob,val,fill)#test with fill a vector
+
 #TODO: test fill_with_array!
     def test_matrix3d(self):
-        shp=(100,100,100)
+        print "test_matrix3d"
+        shp=(10,100,100)
         v= T.TensorType('float64', (False, False, False))()
         op=FillMissing()(v)
         fct=function([v],op)
@@ -84,8 +103,9 @@
                     if prob[i,j,k]<0.1:
                         nb_missing+=1
                         val[i,j,k]=N.nan
-
+        t=time.time()
         out=fct(val)
+        print "time %.3fs"%(time.time()-t)
         for i in range(shp[0]):
             for j in range(shp[1]):
                 for k in range(shp[2]):
@@ -93,6 +113,7 @@
                         assert out[0][i,j,k]==0
                         assert out[1][i,j,k]==0
                     else:
+                        assert out[0][i,j,k]==val[i,j,k]
                         assert out[1][i,j,k]==1
 
 if __name__ == '__main__':