changeset 716:d7c6dadb4aa9

New class FillMissing (should probably be moved to another file later)
author Olivier Delalleau <delallea@iro>
date Mon, 25 May 2009 10:23:00 -0400
parents a268c5ea0db4
children bf29e201515f
files pylearn/sandbox/scan_inputs_groups.py
diffstat 1 files changed, 38 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/sandbox/scan_inputs_groups.py	Fri May 22 14:13:38 2009 -0400
+++ b/pylearn/sandbox/scan_inputs_groups.py	Mon May 25 10:23:00 2009 -0400
@@ -560,3 +560,41 @@
 
 scanmaskenc=ScanMask(True)
 scanmaskdec=ScanMask(False)
+
+class FillMissing(Op):
+    """
+    Given an input, output two elements:
+        - a copy of the input where missing values (NaN) are replaced by
+        `constant_val` (zero by default)
+        - an int8 mask of the same shape as the input, equal to 1 where the
+        corresponding input is not missing and 0 where it is missing
+    """
+
+    def __init__(self, constant_val=0):
+        # Name this class (not Op) so that Op's own __init__ is part of the chain.
+        super(FillMissing, self).__init__()
+        self.constant_val = constant_val
+
+    def __eq__(self, other):
+        """Instances are equal iff they share the same type and fill constant."""
+        return type(self) == type(other) and (self.constant_val == other.constant_val)
+
+    def __hash__(self):
+        """Hash built from the type and the fill constant, matching __eq__."""
+        return hash(type(self)) ^ hash(self.constant_val)
+
+    def make_node(self, input):
+        """Build the Apply node; the mask output is declared as T.bmatrix()."""
+        return Apply(self, [input], [input.type(), T.bmatrix()])
+
+    def perform(self, node, inputs, output_storage):
+        """Write the filled copy to output_storage[0][0], the mask to output_storage[1][0]."""
+        # One vectorized isnan pass instead of a per-element Python loop.
+        missing = numpy.isnan(inputs[0])
+        filled = inputs[0].copy()
+        filled[missing] = self.constant_val
+        output_storage[0][0] = filled
+        output_storage[1][0] = (~missing).astype(numpy.int8)
+
+fill_missing_with_zeros = FillMissing(0)  # module-level instance whose fill constant is 0
+