Mercurial > pylearn
changeset 841:d7ee9c906d7e
added the ability to pass fn_args to TensorFnDataset
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Thu, 22 Oct 2009 18:52:47 -0400 |
parents | 7ccce98da2b6 |
children | 3c1fb6f14a14 |
files | pylearn/dataset_ops/protocol.py |
diffstat | 1 files changed, 10 insertions(+), 6 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/dataset_ops/protocol.py Thu Oct 22 18:52:12 2009 -0400 +++ b/pylearn/dataset_ops/protocol.py Thu Oct 22 18:52:47 2009 -0400 @@ -61,22 +61,26 @@ class TensorFnDataset(TensorDataset): def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None): super(TensorFnDataset, self).__init__(dtype, bcast, single_shape, batch_size) - self.fn = fn + try: + self.fn, self.fn_args = fn + except: + self.fn, self.fn_args = fn, () def __eq__(self, other): - return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn + return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn \ + and self.fn_args == other.fn_args def __hash__(self): - return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) + return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) ^ hash(self.fn_args) def __str__(self): try: - return "%s{%s}" % (self.__class__.__name__, self.fn.__name__) + return "%s{%s,%s}" % (self.__class__.__name__, self.fn.__name__, self.fn_args) except: - return "%s{%s}" % (self.__class__.__name__, self.fn) + return "%s{%s}" % (self.__class__.__name__, self.fn, self.fn_args) def perform(self, node, (idx,), (z,)): - x = self.fn() + x = self.fn(*self.fn_args) if idx.ndim == 0: z[0] = x[int(idx)] else: