diff pylearn/sandbox/scan_inputs_groups.py @ 771:72730f38d1fb

opt of the FillMissing op. Now 80-90% faster python implementation.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Wed, 10 Jun 2009 13:42:56 -0400
parents 742972b6906a
children b6670cb57101
line wrap: on
line diff
--- a/pylearn/sandbox/scan_inputs_groups.py	Tue Jun 09 21:39:50 2009 -0400
+++ b/pylearn/sandbox/scan_inputs_groups.py	Wed Jun 10 13:42:56 2009 -0400
@@ -613,19 +613,46 @@
 
         mask = mask[0]
         if self.fill_with_is_array:
-            ignore_k = len(out.shape) - len(self.fill_with.shape)
-            assert ignore_k >= 0
-        
-        if self.fill_with_is_array:
-            for (idx, v) in numpy.ndenumerate(out):
-                if numpy.isnan(v):
-                    out[idx] = self.fill_with[idx[ignore_k:]]
-                    mask[idx] = 0
+            #numpy.ndenumerate is slower then a loop
+            #so we optimise for some number of dimension frequently used
+            if out.ndim==1:
+                assert self.fill_with.ndim==1
+                for i in range(out.shape[0]):
+                    if numpy.isnan(out[i]):
+                        out[i] = self.fill_with[i]
+                        mask[i] = 0
+            elif out.ndim==2 and self.fill_with.ndim==1:
+                for i in range(out.shape[0]):
+                    for j in range(out.shape[1]):
+                        if numpy.isnan(out[i,j]):
+                            out[i,j] = self.fill_with[j]
+                            mask[i,j] = 0
+            else:
+                ignore_k = out.ndim - self.fill_with.ndim
+                assert ignore_k >= 0
+                for (idx, v) in numpy.ndenumerate(out):
+                    if numpy.isnan(v):
+                        out[idx] = self.fill_with[idx[ignore_k:]]
+                        mask[idx] = 0
         else:
-            for (idx, v) in numpy.ndenumerate(out):
-                if numpy.isnan(v):
-                    out[idx] = self.fill_with
-                    mask[idx] = 0
+            #numpy.ndenumerate is slower then a loop
+            #so we optimise for some number of dimension frequently used
+            if out.ndim==1:
+                for i in range(out.shape[0]):
+                    if numpy.isnan(out[i]):
+                        out[i] = self.fill_with
+                        mask[i] = 0
+            elif out.ndim==2:
+                for i in range(out.shape[0]):
+                    for j in range(out.shape[1]):
+                        if numpy.isnan(out[i,j]):
+                            out[i,j] = self.fill_with
+                            mask[i,j] = 0
+            else:
+                for (idx, v) in numpy.ndenumerate(out):
+                    if numpy.isnan(out[idx]):
+                        out[idx] = self.fill_with
+                        mask[idx] = 0
 
     def grad(self, inputs, (out_grad, mask_grad, )):
         return [out_grad]