Mercurial > pylearn
changeset 716:d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
author | Olivier Delalleau <delallea@iro> |
---|---|
date | Mon, 25 May 2009 10:23:00 -0400 |
parents | a268c5ea0db4 |
children | bf29e201515f |
files | pylearn/sandbox/scan_inputs_groups.py |
diffstat | 1 files changed, 38 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/sandbox/scan_inputs_groups.py Fri May 22 14:13:38 2009 -0400 +++ b/pylearn/sandbox/scan_inputs_groups.py Mon May 25 10:23:00 2009 -0400 @@ -560,3 +560,41 @@ scanmaskenc=ScanMask(True) scanmaskdec=ScanMask(False) + +class FillMissing(Op): + """ + Given an input, output two elements: + - a copy of the input where missing values (NaN) are replaced by some + constant (zero by default) + - a boolean (actually int8) mask of the same size as input, where each + element is True (i.e. 1) iff the corresponding input is not missing + """ + + def __init__(self, constant_val=0): + super(Op, self).__init__() + self.constant_val = constant_val + + def __eq__(self, other): + return type(self) == type(other) and (self.constant_val == other.constant_val) + + def __hash__(self): + return hash(type(self))^hash(self.constant_val) + + def make_node(self, input): + return Apply(self, [input], [input.type(), T.bmatrix()]) + + def perform(self, node, inputs, output_storage): + input = inputs[0] + out = output_storage[0] + out[0] = input.copy() + out = out[0] + mask = output_storage[1] + mask[0] = numpy.ones(input.shape, dtype = numpy.int8) + mask = mask[0] + for (idx, v) in numpy.ndenumerate(out): + if numpy.isnan(v): + out[idx] = self.constant_val + mask[idx] = 0 + +fill_missing_with_zeros = FillMissing(0) +