changeset 717:bf29e201515f

New class MaskSelect (should probably be moved to another file)
author Olivier Delalleau <delallea@iro>
date Mon, 25 May 2009 11:42:25 -0400
parents d7c6dadb4aa9
children 88f5b75a4afe
files pylearn/sandbox/scan_inputs_groups.py
diffstat 1 files changed, 28 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/sandbox/scan_inputs_groups.py	Mon May 25 10:23:00 2009 -0400
+++ b/pylearn/sandbox/scan_inputs_groups.py	Mon May 25 11:42:25 2009 -0400
@@ -561,6 +561,8 @@
 scanmaskenc=ScanMask(True)
 scanmaskdec=ScanMask(False)
 
+# TODO The classes FillMissing and MaskSelect below should probably be moved
+# to another (more appropriate) file.
 class FillMissing(Op):
     """
     Given an input, output two elements:
@@ -598,3 +600,29 @@
 
 fill_missing_with_zeros = FillMissing(0)
 
+class MaskSelect(Op):
+    """
+    Given an input x and a mask m (both vectors), outputs a vector that
+    contains all elements x[i] such that bool(m[i]) is True.
+    """
+
+    def __eq__(self, other):
+        return type(self) == type(other)
+
+	def __hash__(self):
+		return hash(type(self))
+	
+    def make_node(self, input, mask):
+        return Apply(self, [input, mask], [input.type()])
+
+    def perform(self, node, (input, mask), (output, )):
+        select = []
+        for (i, m) in enumerate(mask):
+            if bool(m):
+                select.append(i)
+        output[0] = numpy.zeros(len(select), dtype = input.dtype)
+        out = output[0]
+        for (i, j) in enumerate(select):
+            out[i] = input[j]
+
+mask_select = MaskSelect()