Mercurial > pylearn
comparison pylearn/sandbox/scan_inputs_groups.py @ 770:742972b6906a
small optimisation to FillMissing.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 09 Jun 2009 21:39:50 -0400 |
parents | bd95e8ea99d8 |
children | 72730f38d1fb |
comparison
equal
deleted
inserted
replaced
769:0ff7ac3253b3 | 770:742972b6906a |
---|---|
605 def perform(self, node, (input, ), output_storage): | 605 def perform(self, node, (input, ), output_storage): |
606 out = output_storage[0] | 606 out = output_storage[0] |
607 out[0] = input.copy() | 607 out[0] = input.copy() |
608 out = out[0] | 608 out = out[0] |
609 mask = output_storage[1] | 609 mask = output_storage[1] |
610 mask[0] = numpy.ones(input.shape) | 610 |
611 if mask[0] is None or mask[0].shape!=input.shape: | |
612 mask[0] = numpy.ones(input.shape) | |
613 | |
611 mask = mask[0] | 614 mask = mask[0] |
612 if self.fill_with_is_array: | 615 if self.fill_with_is_array: |
613 ignore_k = len(out.shape) - len(self.fill_with.shape) | 616 ignore_k = len(out.shape) - len(self.fill_with.shape) |
614 assert ignore_k >= 0 | 617 assert ignore_k >= 0 |
615 for (idx, v) in numpy.ndenumerate(out): | 618 |
616 if numpy.isnan(v): | 619 if self.fill_with_is_array: |
617 if self.fill_with_is_array: | 620 for (idx, v) in numpy.ndenumerate(out): |
621 if numpy.isnan(v): | |
618 out[idx] = self.fill_with[idx[ignore_k:]] | 622 out[idx] = self.fill_with[idx[ignore_k:]] |
619 else: | 623 mask[idx] = 0 |
624 else: | |
625 for (idx, v) in numpy.ndenumerate(out): | |
626 if numpy.isnan(v): | |
620 out[idx] = self.fill_with | 627 out[idx] = self.fill_with |
621 mask[idx] = 0 | 628 mask[idx] = 0 |
622 | 629 |
623 def grad(self, inputs, (out_grad, mask_grad, )): | 630 def grad(self, inputs, (out_grad, mask_grad, )): |
624 return [out_grad] | 631 return [out_grad] |
625 | 632 |
626 fill_missing_with_zeros = FillMissing(0) | 633 fill_missing_with_zeros = FillMissing(0) |