changeset 770:742972b6906a

scall optimisation to FillMissing.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 09 Jun 2009 21:39:50 -0400
parents 0ff7ac3253b3
children 72730f38d1fb
files pylearn/sandbox/scan_inputs_groups.py
diffstat 1 files changed, 13 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/sandbox/scan_inputs_groups.py	Tue Jun 09 21:37:58 2009 -0400
+++ b/pylearn/sandbox/scan_inputs_groups.py	Tue Jun 09 21:39:50 2009 -0400
@@ -607,18 +607,25 @@
         out[0] = input.copy()
         out = out[0]
         mask = output_storage[1]
-        mask[0] = numpy.ones(input.shape)
+        
+        if mask[0] is None or mask[0].shape!=input.shape:
+            mask[0] = numpy.ones(input.shape)
+
         mask = mask[0]
         if self.fill_with_is_array:
             ignore_k = len(out.shape) - len(self.fill_with.shape)
             assert ignore_k >= 0
-        for (idx, v) in numpy.ndenumerate(out):
-            if numpy.isnan(v):
-                if self.fill_with_is_array:
+        
+        if self.fill_with_is_array:
+            for (idx, v) in numpy.ndenumerate(out):
+                if numpy.isnan(v):
                     out[idx] = self.fill_with[idx[ignore_k:]]
-                else:
+                    mask[idx] = 0
+        else:
+            for (idx, v) in numpy.ndenumerate(out):
+                if numpy.isnan(v):
                     out[idx] = self.fill_with
-                mask[idx] = 0
+                    mask[idx] = 0
 
     def grad(self, inputs, (out_grad, mask_grad, )):
         return [out_grad]