# HG changeset patch
# User James Bergstra
# Date 1296853377 18000
# Node ID f49801e39fe311e702a7989ff5bddf7f441eb17e
# Parent  28b2f17991aa03325cb2fb7bd2a4531cffac24b7
TensorDataset - added single_shape and batch_size properties

diff -r 28b2f17991aa -r f49801e39fe3 pylearn/dataset_ops/protocol.py
--- a/pylearn/dataset_ops/protocol.py   Fri Feb 04 16:01:45 2011 -0500
+++ b/pylearn/dataset_ops/protocol.py   Fri Feb 04 16:02:57 2011 -0500
@@ -48,15 +48,23 @@
     """
     def __init__(self, dtype, single_broadcastable, single_shape=None, batch_size=None):
         single_broadcastable = tuple(single_broadcastable)
+        self.single_shape = single_shape
+        self.batch_size = batch_size
         single_type = theano.tensor.Tensor(
-                broadcastable=single_broadcastable,
+                broadcastable=single_broadcastable,
                 dtype=dtype)
-                #shape=single_shape)
         batch_type = theano.tensor.Tensor(
                 broadcastable=(False,)+single_type.broadcastable,
                 dtype=dtype)
-                #shape=(batch_size,)+single_type.shape)
         super(TensorDataset, self).__init__(single_type, batch_type)
+    def __eq__(self, other):
+        return (super(TensorDataset, self).__eq__(other)
+                and self.single_shape == other.single_shape
+                and self.batch_size == other.batch_size)
+    def __hash__(self):
+        return (super(TensorDataset, self).__hash__()
+                ^ hash(self.single_shape)
+                ^ hash(self.batch_size))
 
 class TensorFnDataset(TensorDataset):
     """A good base class for TensorDatasets that can be read from disk and cached in memory
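
Usage note (illustrative sketch, not part of the patch): the snippet below assumes pylearn and theano are importable and that TensorDataset is instantiated directly with the constructor arguments shown in the hunk above; the dtype, broadcastable pattern, shapes, and batch sizes are made-up values. It shows the intended effect of the new __eq__/__hash__: datasets agreeing on type, single_shape, and batch_size compare equal and hash identically, while a differing batch_size breaks equality.

# Hypothetical example values; only the import path, class name, and
# constructor signature come from the patch above.
from pylearn.dataset_ops.protocol import TensorDataset

a = TensorDataset('float32', single_broadcastable=(False,),
                  single_shape=(784,), batch_size=100)
b = TensorDataset('float32', single_broadcastable=(False,),
                  single_shape=(784,), batch_size=100)
c = TensorDataset('float32', single_broadcastable=(False,),
                  single_shape=(784,), batch_size=50)

# single_shape and batch_size now participate in equality and hashing
# (assuming the inherited comparison of the underlying Tensor types passes).
assert a == b
assert hash(a) == hash(b)
assert not (a == c)   # batch_size differs, so __eq__ returns False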