# HG changeset patch # User Olivier Delalleau # Date 1243266145 14400 # Node ID bf29e201515fb23f5b859ae12e83230e7ebca9af # Parent d7c6dadb4aa9d1277059b1b6edfea2eacfe518ef New class MaskSelect (should probably be moved to another file) diff -r d7c6dadb4aa9 -r bf29e201515f pylearn/sandbox/scan_inputs_groups.py --- a/pylearn/sandbox/scan_inputs_groups.py Mon May 25 10:23:00 2009 -0400 +++ b/pylearn/sandbox/scan_inputs_groups.py Mon May 25 11:42:25 2009 -0400 @@ -561,6 +561,8 @@ scanmaskenc=ScanMask(True) scanmaskdec=ScanMask(False) +# TODO The classes FillMissing and MaskSelect below should probably be moved +# to another (more appropriate) file. class FillMissing(Op): """ Given an input, output two elements: @@ -598,3 +600,29 @@ fill_missing_with_zeros = FillMissing(0) +class MaskSelect(Op): + """ + Given an input x and a mask m (both vectors), outputs a vector that + contains all elements x[i] such that bool(m[i]) is True. + """ + + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return hash(type(self)) + + def make_node(self, input, mask): + return Apply(self, [input, mask], [input.type()]) + + def perform(self, node, (input, mask), (output, )): + select = [] + for (i, m) in enumerate(mask): + if bool(m): + select.append(i) + output[0] = numpy.zeros(len(select), dtype = input.dtype) + out = output[0] + for (i, j) in enumerate(select): + out[i] = input[j] + +mask_select = MaskSelect()