Mercurial > pylearn
annotate pylearn/dataset_ops/protocol.py @ 1417:f49801e39fe3
TensorDataset - added single_shape and batch_size properties
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 04 Feb 2011 16:02:57 -0500 |
parents | 5cbd235f8fb8 |
children | 383d4c061546 |
rev | line source |
---|---|
832 | 1 """Convenience base classes to help with writing Dataset ops |
2 | |
3 """ | |
4 | |
5 __docformat__ = "restructuredtext_en" | |
6 import theano | |
7 | |
class Dataset(theano.Op):
    """
    The basic dataset interface is an expression that maps an integer to a dataset element.

    There is also a minibatch option, in which the expression maps an array of integers to a
    list or array of dataset elements.

    :param single_type: called with no arguments to build the output variable when the index
        is a scalar (0-d)

    :param batch_type: called with no arguments to build the output variable when the index
        is a vector (1-d)
    """
    def __init__(self, single_type, batch_type):
        self.single_type = single_type
        self.batch_type = batch_type

    def make_node(self, idx):
        """Build the Apply node that looks up `idx` in this dataset.

        :param idx: anything accepted by `theano.tensor.as_tensor_variable`.  Its dtype name
            must start with 'int' (so unsigned dtypes such as 'uint8' are rejected) and it
            must be 0-dimensional (one example) or 1-dimensional (a minibatch).

        :raises TypeError: if the index dtype is not an integer dtype, or if the index has
            more than one dimension.
        """
        _idx = theano.tensor.as_tensor_variable(idx)
        if not _idx.dtype.startswith('int'):
            raise TypeError(
                    'dataset index must have an integer dtype, got %s' % _idx.dtype)
        if _idx.ndim == 0: # one example at a time
            otype = self.single_type
        elif _idx.ndim == 1: # many examples at a time
            otype = self.batch_type
        else:
            raise TypeError(
                    'dataset index must be a scalar or a vector, got ndim=%s' % _idx.ndim,
                    idx)
        return theano.Apply(self, [_idx], [otype()])

    def __eq__(self, other):
        # Two dataset ops are interchangeable only when they are of exactly the same class
        # and produce the same output types.
        return type(self) == type(other) \
                and self.single_type == other.single_type \
                and self.batch_type == other.batch_type

    def __hash__(self):
        # Combines exactly the fields compared by __eq__.
        return hash(type(self)) ^ hash(self.single_type) ^ hash(self.batch_type)

    def __str__(self):
        return "%s{%s,%s}" % (self.__class__.__name__, self.single_type, self.batch_type)

    def grad(self, inputs, g_outputs):
        """Return one None per input: no gradient is defined w.r.t. the integer index."""
        return [None for _ in inputs]
44 | |
45 | |
class TensorDataset(Dataset):
    """A convenient base class for Datasets whose elements all have the same TensorType.

    The single-example type is a tensor with the given dtype and broadcastable pattern; the
    minibatch type is the same with one extra, non-broadcastable, leading dimension.
    """
    def __init__(self, dtype, single_broadcastable, single_shape=None, batch_size=None):
        """
        :param dtype: dtype of both the single-example and the minibatch tensor types

        :param single_broadcastable: broadcastable pattern of a single example (any iterable;
            it is converted to a tuple)

        :param single_shape: stored as `self.single_shape` and taken into account by
            `__eq__` and `__hash__`.  Presumably the shape of one example -- it is not
            otherwise used in this class.  A list is converted to a tuple so that the Op
            stays hashable.

        :param batch_size: stored as `self.batch_size` and taken into account by `__eq__`
            and `__hash__`.  Presumably the fixed minibatch size, if any -- it is not
            otherwise used in this class.
        """
        single_broadcastable = tuple(single_broadcastable)
        # __hash__ hashes single_shape, so a list here would make the Op unhashable.
        if isinstance(single_shape, list):
            single_shape = tuple(single_shape)
        self.single_shape = single_shape
        self.batch_size = batch_size
        single_type = theano.tensor.Tensor(
                broadcastable=single_broadcastable,
                dtype=dtype)
        batch_type = theano.tensor.Tensor(
                broadcastable=(False,)+single_type.broadcastable,
                dtype=dtype)
        super(TensorDataset, self).__init__(single_type, batch_type)

    def __eq__(self, other):
        # The base-class check guarantees `other` is of the same class, so it has the
        # attributes compared below.
        return (super(TensorDataset, self).__eq__(other)
                and self.single_shape == other.single_shape
                and self.batch_size == other.batch_size)

    def __hash__(self):
        # Combines the base-class hash with the extra fields compared by __eq__.
        return (super(TensorDataset, self).__hash__()
                ^ hash(self.single_shape)
                ^ hash(self.batch_size))
832 | 68 |
class TensorFnDataset(TensorDataset):
    """A good base class for TensorDatasets that can be read from disk and cached in memory

    The dataset is accessed via a function call to make this Op pickle-able.  If the function
    is a normal module-level function, then this Op will be picklable.  If the dataset were a
    property of this Op, then pickling the Op would require pickling the entire dataset.
    """
    def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None):
        """
        :type fn: callable or (callable, args) tuple [MUST BE PICKLABLE!]

        :param fn: function that returns the dataset as a tensor.  Leading index is the example
            index, others are considered part of each example.  When given as a
            (callable, args) pair, the callable is invoked as `callable(*args)`.

        The other arguments are forwarded unchanged to `TensorDataset.__init__`.
        """
        super(TensorFnDataset, self).__init__(dtype, bcast, single_shape, batch_size)
        try:
            dataset_fn, dataset_fn_args = fn
        except (TypeError, ValueError):
            # `fn` is not a 2-element sequence: treat it as a bare callable with no arguments.
            dataset_fn, dataset_fn_args = fn, ()
        # __hash__ hashes fn_args, so a list here would make the Op unhashable.
        if isinstance(dataset_fn_args, list):
            dataset_fn_args = tuple(dataset_fn_args)
        self.fn = dataset_fn
        self.fn_args = dataset_fn_args

    def __eq__(self, other):
        # The base-class check guarantees `other` is of the same class, so it has `fn` and
        # `fn_args`.
        return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn \
                and self.fn_args == other.fn_args

    def __hash__(self):
        # Combines the base-class hash with the extra fields compared by __eq__.
        return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) ^ hash(self.fn_args)

    def __str__(self):
        # Prefer the function's short name; fall back to the callable itself when it has
        # no __name__ attribute.
        fn_label = getattr(self.fn, '__name__', self.fn)
        return "%s{%s,%s}" % (self.__class__.__name__, fn_label, self.fn_args)

    def perform(self, node, inputs, output_storage):
        """Fetch the dataset via `self.fn(*self.fn_args)` and index it.

        :param node: the Apply node being evaluated (unused)

        :param inputs: one-element sequence holding the index array

        :param output_storage: one-element sequence holding the output cell; the result is
            written to its item 0

        A 0-d index selects a single example (the index is converted to a Python int first);
        any other index is used directly to select a minibatch.
        """
        idx, = inputs
        z, = output_storage
        x = self.fn(*self.fn_args)
        if idx.ndim == 0:
            z[0] = x[int(idx)]
        else:
            z[0] = x[idx]