# HG changeset patch # User James Bergstra # Date 1256251967 14400 # Node ID d7ee9c906d7eb8bc3878bb6ecc40f819304e20d4 # Parent 7ccce98da2b6438a4bc9bedf9ee0e1d4f9c09fc6 added the ability to pass fn_args to TensorFnDataset diff -r 7ccce98da2b6 -r d7ee9c906d7e pylearn/dataset_ops/protocol.py --- a/pylearn/dataset_ops/protocol.py Thu Oct 22 18:52:12 2009 -0400 +++ b/pylearn/dataset_ops/protocol.py Thu Oct 22 18:52:47 2009 -0400 @@ -61,22 +61,26 @@ class TensorFnDataset(TensorDataset): def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None): super(TensorFnDataset, self).__init__(dtype, bcast, single_shape, batch_size) - self.fn = fn + try: + self.fn, self.fn_args = fn + except: + self.fn, self.fn_args = fn, () def __eq__(self, other): - return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn + return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn \ + and self.fn_args == other.fn_args def __hash__(self): - return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) + return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) ^ hash(self.fn_args) def __str__(self): try: - return "%s{%s}" % (self.__class__.__name__, self.fn.__name__) + return "%s{%s,%s}" % (self.__class__.__name__, self.fn.__name__, self.fn_args) except: - return "%s{%s}" % (self.__class__.__name__, self.fn) + return "%s{%s}" % (self.__class__.__name__, self.fn, self.fn_args) def perform(self, node, (idx,), (z,)): - x = self.fn() + x = self.fn(*self.fn_args) if idx.ndim == 0: z[0] = x[int(idx)] else: