changeset 841:d7ee9c906d7e

added the ability to pass fn_args to TensorFnDataset
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 22 Oct 2009 18:52:47 -0400
parents 7ccce98da2b6
children 3c1fb6f14a14
files pylearn/dataset_ops/protocol.py
diffstat 1 files changed, 10 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dataset_ops/protocol.py	Thu Oct 22 18:52:12 2009 -0400
+++ b/pylearn/dataset_ops/protocol.py	Thu Oct 22 18:52:47 2009 -0400
@@ -61,22 +61,26 @@
 class TensorFnDataset(TensorDataset):
     def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None):
         super(TensorFnDataset, self).__init__(dtype, bcast, single_shape, batch_size)
-        self.fn = fn
+        try:
+            self.fn, self.fn_args = fn
+        except:
+            self.fn, self.fn_args = fn, ()
 
     def __eq__(self, other):
-        return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn
+        return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn \
+                and self.fn_args == other.fn_args
 
     def __hash__(self):
-        return super(TensorFnDataset, self).__hash__() ^ hash(self.fn)
+        return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) ^ hash(self.fn_args)
 
     def __str__(self):
         try:
-            return "%s{%s}" % (self.__class__.__name__, self.fn.__name__)
+            return "%s{%s,%s}" % (self.__class__.__name__, self.fn.__name__, self.fn_args)
         except:
-            return "%s{%s}" % (self.__class__.__name__, self.fn)
+            return "%s{%s}" % (self.__class__.__name__, self.fn, self.fn_args)
 
     def perform(self, node, (idx,), (z,)):
-        x = self.fn()
+        x = self.fn(*self.fn_args)
         if idx.ndim == 0:
             z[0] = x[int(idx)]
         else: