Mercurial > pylearn
comparison pylearn/dataset_ops/protocol.py @ 1417:f49801e39fe3
TensorDataset - added single_shape and batch_size properties
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 04 Feb 2011 16:02:57 -0500 |
parents | 5cbd235f8fb8 |
children | 383d4c061546 |
comparison
equal
deleted
inserted
replaced
1416:28b2f17991aa | 1417:f49801e39fe3 |
---|---|
46 class TensorDataset(Dataset): | 46 class TensorDataset(Dataset): |
47 """A convenient base class for Datasets whose elements all have the same TensorType. | 47 """A convenient base class for Datasets whose elements all have the same TensorType. |
48 """ | 48 """ |
49 def __init__(self, dtype, single_broadcastable, single_shape=None, batch_size=None): | 49 def __init__(self, dtype, single_broadcastable, single_shape=None, batch_size=None): |
50 single_broadcastable = tuple(single_broadcastable) | 50 single_broadcastable = tuple(single_broadcastable) |
51 self.single_shape = single_shape | |
52 self.batch_size = batch_size | |
51 single_type = theano.tensor.Tensor( | 53 single_type = theano.tensor.Tensor( |
52 broadcastable=single_broadcastable, | 54 broadcastable=single_broadcastable, |
53 dtype=dtype) | 55 dtype=dtype) |
54 #shape=single_shape) | |
55 batch_type = theano.tensor.Tensor( | 56 batch_type = theano.tensor.Tensor( |
56 broadcastable=(False,)+single_type.broadcastable, | 57 broadcastable=(False,)+single_type.broadcastable, |
57 dtype=dtype) | 58 dtype=dtype) |
58 #shape=(batch_size,)+single_type.shape) | |
59 super(TensorDataset, self).__init__(single_type, batch_type) | 59 super(TensorDataset, self).__init__(single_type, batch_type) |
60 def __eq__(self, other): | |
61 return (super(TensorDataset, self).__eq__(other) | |
62 and self.single_shape == other.single_shape | |
63 and self.batch_size == other.batch_size) | |
64 def __hash__(self): | |
65 return (super(TensorDataset, self).__hash__() | |
66 ^ hash(self.single_shape) | |
67 ^ hash(self.batch_size)) | |
60 | 68 |
61 class TensorFnDataset(TensorDataset): | 69 class TensorFnDataset(TensorDataset): |
62 """A good base class for TensorDatasets that can be read from disk and cached in memory | 70 """A good base class for TensorDatasets that can be read from disk and cached in memory |
63 | 71 |
64 The dataset is accessed via a function call to make this Op pickle-able. If the function | 72 The dataset is accessed via a function call to make this Op pickle-able. If the function |