Mercurial > pylearn
changeset 717:bf29e201515f
New class MaskSelect (should probably be moved to another file)
author | Olivier Delalleau <delallea@iro> |
---|---|
date | Mon, 25 May 2009 11:42:25 -0400 |
parents | d7c6dadb4aa9 |
children | 88f5b75a4afe |
files | pylearn/sandbox/scan_inputs_groups.py |
diffstat | 1 files changed, 28 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/sandbox/scan_inputs_groups.py Mon May 25 10:23:00 2009 -0400 +++ b/pylearn/sandbox/scan_inputs_groups.py Mon May 25 11:42:25 2009 -0400 @@ -561,6 +561,8 @@ scanmaskenc=ScanMask(True) scanmaskdec=ScanMask(False) +# TODO The classes FillMissing and MaskSelect below should probably be moved +# to another (more appropriate) file. class FillMissing(Op): """ Given an input, output two elements: @@ -598,3 +600,29 @@ fill_missing_with_zeros = FillMissing(0) +class MaskSelect(Op): + """ + Given an input x and a mask m (both vectors), outputs a vector that + contains all elements x[i] such that bool(m[i]) is True. + """ + + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return hash(type(self)) + + def make_node(self, input, mask): + return Apply(self, [input, mask], [input.type()]) + + def perform(self, node, (input, mask), (output, )): + select = [] + for (i, m) in enumerate(mask): + if bool(m): + select.append(i) + output[0] = numpy.zeros(len(select), dtype = input.dtype) + out = output[0] + for (i, j) in enumerate(select): + out[i] = input[j] + +mask_select = MaskSelect()