# HG changeset patch
# User Olivier Delalleau <delallea@iro>
# Date 1243261380 14400
# Node ID d7c6dadb4aa9d1277059b1b6edfea2eacfe518ef
# Parent  a268c5ea0db4086008b0ce0430603f2d3971900a
New class FillMissing (should probably be moved to another file later)

diff -r a268c5ea0db4 -r d7c6dadb4aa9 pylearn/sandbox/scan_inputs_groups.py
--- a/pylearn/sandbox/scan_inputs_groups.py	Fri May 22 14:13:38 2009 -0400
+++ b/pylearn/sandbox/scan_inputs_groups.py	Mon May 25 10:23:00 2009 -0400
@@ -560,3 +560,41 @@
 
 scanmaskenc=ScanMask(True)
 scanmaskdec=ScanMask(False)
+
+class FillMissing(Op):
+    """
+    Given an input, output two elements:
+        - a copy of the input where missing values (NaN) are replaced by some
+        constant (zero by default)
+        - a boolean (actually int8) mask of the same size as input, where each
+        element is True (i.e. 1) iff the corresponding input is not missing
+    """
+
+    def __init__(self, constant_val=0):
+        super(Op, self).__init__()
+        self.constant_val = constant_val
+
+    def __eq__(self, other):
+        return type(self) == type(other) and (self.constant_val == other.constant_val)
+
+	def __hash__(self):
+		return hash(type(self))^hash(self.constant_val)
+	
+    def make_node(self, input):
+        return Apply(self, [input], [input.type(), T.bmatrix()])
+
+    def perform(self, node, inputs, output_storage):
+        input = inputs[0]
+        out = output_storage[0]
+        out[0] = input.copy()
+        out = out[0]
+        mask = output_storage[1]
+        mask[0] = numpy.ones(input.shape, dtype = numpy.int8)
+        mask = mask[0]
+        for (idx, v) in numpy.ndenumerate(out):
+            if numpy.isnan(v):
+                out[idx] = self.constant_val
+                mask[idx] = 0
+
+fill_missing_with_zeros = FillMissing(0)
+