Mercurial > pylearn
annotate pylearn/dataset_ops/protocol.py @ 841:d7ee9c906d7e
added the ability to pass fn_args to TensorFnDataset
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Thu, 22 Oct 2009 18:52:47 -0400 |
parents | 67b92a42f86b |
children | b2821fce15de |
rev | line source |
---|---|
"""Convenience base classes to help with writing Dataset ops.

A Dataset op is an expression mapping an integer index (or a vector of
indices) to the corresponding dataset element (or list/array of elements).
"""

__docformat__ = "restructuredtext_en"
6 import theano | |
7 | |
class Dataset(theano.Op):
    """
    The basic dataset interface is an expression that maps an integer to a dataset element.

    There is also a minibatch option, in which the expression maps an array of integers to a
    list or array of dataset elements.

    :param single_type: Theano type of the output when the index is a scalar.
    :param batch_type: Theano type of the output when the index is a vector
        (minibatch).
    """

    def __init__(self, single_type, batch_type):
        self.single_type = single_type
        self.batch_type = batch_type

    def make_node(self, idx):
        """Return an Apply node selecting element(s) `idx` of this dataset.

        :param idx: anything `theano.tensor.as_tensor_variable` accepts, with an
            'int*' dtype. A scalar selects one example (output type
            `single_type`); a vector selects a minibatch (output type
            `batch_type`).
        :raises TypeError: if the index dtype does not start with 'int', or if
            the index has more than one dimension.
        """
        _idx = theano.tensor.as_tensor_variable(idx)
        # Only dtypes named 'int*' are accepted (so e.g. 'uint8' is rejected).
        if not _idx.dtype.startswith('int'):
            raise TypeError('%s index must have an integer dtype, got %s'
                            % (self.__class__.__name__, _idx.dtype))
        if _idx.ndim == 0:  # one example at a time
            otype = self.single_type
        elif _idx.ndim == 1:  # many examples at a time
            otype = self.batch_type
        else:
            raise TypeError('%s index must be a scalar or a vector, got ndim=%i: %r'
                            % (self.__class__.__name__, _idx.ndim, idx))
        return theano.Apply(self, [_idx], [otype()])

    def __eq__(self, other):
        # Ops are interchangeable when class and both output types match.
        return type(self) == type(other) \
            and self.single_type == other.single_type \
            and self.batch_type == other.batch_type

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, != would
        # compare identity and disagree with ==.
        return not self == other

    def __hash__(self):
        # Consistent with __eq__: built from the same three components.
        return hash(type(self)) ^ hash(self.single_type) ^ hash(self.batch_type)

    def __str__(self):
        return "%s{%s,%s}" % (self.__class__.__name__, self.single_type, self.batch_type)

    def grad(self, inputs, g_outputs):
        # The only input is an integer index, so no gradient is returned
        # for it: one None per input.
        return [None for i in inputs]
class TensorDataset(Dataset):
    """Base class for Datasets whose elements all share one TensorType.

    :param dtype: dtype used for both the single-element and minibatch types.
    :param single_broadcastable: broadcastable pattern of one element; any
        iterable of booleans (it is converted to a tuple).
    :param single_shape: forwarded as ``shape=`` when building the
        single-element type.
    :param batch_size: leading entry of the minibatch type's ``shape``.
    """
    def __init__(self, dtype, single_broadcastable, single_shape=None, batch_size=None):
        elem_type = theano.tensor.Tensor(
            broadcastable=tuple(single_broadcastable),
            dtype=dtype,
            shape=single_shape)
        # A minibatch prepends one non-broadcastable axis to the element type.
        minibatch_bcast = (False,) + elem_type.broadcastable
        # NOTE(review): assumes elem_type.shape is a tuple even when
        # single_shape is None -- confirm against theano.tensor.Tensor.
        minibatch_shape = (batch_size,) + elem_type.shape
        minibatch_type = theano.tensor.Tensor(
            broadcastable=minibatch_bcast,
            dtype=dtype,
            shape=minibatch_shape)
        super(TensorDataset, self).__init__(elem_type, minibatch_type)
class TensorFnDataset(TensorDataset):
    """TensorDataset whose data are produced by calling a Python function.

    :param dtype: element dtype (forwarded to TensorDataset).
    :param bcast: broadcastable pattern of a single element (forwarded to
        TensorDataset).
    :param fn: either a callable, or a ``(callable, args)`` pair. At perform
        time ``callable(*args)`` is evaluated (``args`` defaults to ``()``) and
        the result is indexed with the Op's integer input.
    :param single_shape: forwarded to TensorDataset.
    :param batch_size: forwarded to TensorDataset.
    """

    def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None):
        super(TensorFnDataset, self).__init__(dtype, bcast, single_shape, batch_size)
        try:
            self.fn, fn_args = fn
        except (TypeError, ValueError):
            # `fn` is not a 2-item iterable, so it is the callable itself.
            self.fn, fn_args = fn, ()
        # Stored as a tuple so that __hash__ also works when a list was given.
        self.fn_args = tuple(fn_args)

    def __eq__(self, other):
        return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn \
            and self.fn_args == other.fn_args

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self == other

    def __hash__(self):
        # Consistent with __eq__: parent components plus fn and fn_args.
        return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) ^ hash(self.fn_args)

    def __str__(self):
        # Prefer the function's __name__; fall back to str(fn) for callables
        # that have none (e.g. functools.partial objects).
        fn_desc = getattr(self.fn, '__name__', self.fn)
        return "%s{%s,%s}" % (self.__class__.__name__, fn_desc, self.fn_args)

    def perform(self, node, inputs, outputs):
        """Evaluate ``fn(*fn_args)`` and store the element(s) selected by the index.

        A 0-d index is converted with ``int`` and selects a single element.
        Otherwise the index array is used directly and selects a minibatch.
        """
        (idx,) = inputs
        (z,) = outputs
        # NOTE(review): fn is re-invoked on every perform; presumably it caches
        # its result internally -- confirm for expensive loaders.
        x = self.fn(*self.fn_args)
        if idx.ndim == 0:
            z[0] = x[int(idx)]
        else:
            z[0] = x[idx]