# HG changeset patch
# User Olivier Delalleau
# Date 1243261380 14400
# Node ID d7c6dadb4aa9d1277059b1b6edfea2eacfe518ef
# Parent a268c5ea0db4086008b0ce0430603f2d3971900a
# New class FillMissing (should probably be moved to another file later)
#
# diff -r a268c5ea0db4 -r d7c6dadb4aa9 pylearn/sandbox/scan_inputs_groups.py
# --- a/pylearn/sandbox/scan_inputs_groups.py Fri May 22 14:13:38 2009 -0400
# +++ b/pylearn/sandbox/scan_inputs_groups.py Mon May 25 10:23:00 2009 -0400
# @@ -560,3 +560,41 @@
# Context lines already present in the target file:
#     scanmaskenc=ScanMask(True)
#     scanmaskdec=ScanMask(False)
#
# NOTE(review): `Op`, `Apply`, `T` and `numpy` are not imported here; they are
# presumably provided by the target module
# (pylearn/sandbox/scan_inputs_groups.py) -- confirm before using this code
# standalone.


class FillMissing(Op):
    """Replace missing values (NaN) in an input and report where they were.

    Given an input, output two elements:
        - a copy of the input where missing values (NaN) are replaced by
          some constant (zero by default)
        - a boolean (actually int8) mask of the same size as input, where
          each element is True (i.e. 1) iff the corresponding input is not
          missing
    """

    def __init__(self, constant_val=0):
        """Store the replacement value.

        :param constant_val: value written in place of every NaN entry.
        """
        # Start the MRO lookup at this class so that Op's own initializer
        # (if it defines one) runs; `super(Op, self)` would skip it.
        super(FillMissing, self).__init__()
        self.constant_val = constant_val

    def __eq__(self, other):
        """Two instances are equal iff they have the exact same class and
        the same replacement constant."""
        return (type(self) == type(other)
                and self.constant_val == other.constant_val)

    def __hash__(self):
        """Hash consistently with `__eq__` (class and constant)."""
        return hash(type(self)) ^ hash(self.constant_val)

    def make_node(self, input):
        """Build the Apply node for `input`.

        The node has one input and two outputs: a fresh variable created
        from `input.type` (the filled copy) and a `T.bmatrix()` variable
        (the mask).
        """
        # NOTE(review): the mask output is always declared with T.bmatrix(),
        # whatever the input is -- presumably inputs are expected to be
        # matrices; confirm before using this Op on other ranks.
        return Apply(self, [input], [input.type(), T.bmatrix()])

    def perform(self, node, inputs, output_storage):
        """Compute the filled copy and the presence mask.

        :param node: unused here.
        :param inputs: sequence whose first element is the array to process.
        :param output_storage: sequence of two one-element containers;
            slot 0 of the first receives the filled copy, slot 0 of the
            second receives the int8 mask (1 = present, 0 = was NaN).
        """
        data = inputs[0]
        # One vectorized test instead of a Python-level per-element loop.
        missing = numpy.isnan(data)

        filled = data.copy()
        filled[missing] = self.constant_val

        # Start from all-ones and clear the missing positions, so the mask
        # always has exactly the input's shape and an int8 dtype.
        present = numpy.ones(data.shape, dtype=numpy.int8)
        present[missing] = 0

        output_storage[0][0] = filled
        output_storage[1][0] = present


fill_missing_with_zeros = FillMissing(0)