Mercurial > pylearn
annotate pylearn/dataset_ops/protocol.py @ 998:8ba8b08e0442
added the image_patches dataset used in RanzatoHinton2010
modified mcRBM to use it.
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 24 Aug 2010 16:51:53 -0400 |
parents | 5cbd235f8fb8 |
children | f49801e39fe3 |
rev | line source |
---|---|
832 | 1 """Convenience base classes to help with writing Dataset ops |
2 | |
3 """ | |
4 | |
5 __docformat__ = "restructuredtext_en" | |
6 import theano | |
7 | |
class Dataset(theano.Op):
    """
    The basic dataset interface is an expression that maps an integer to a dataset element.

    There is also a minibatch option, in which the expression maps an array of integers to a
    list or array of dataset elements.

    :param single_type: callable returning the output variable used for a scalar index
    :param batch_type: callable returning the output variable used for a vector of indices
    """
    def __init__(self, single_type, batch_type):
        self.single_type = single_type
        self.batch_type = batch_type

    def make_node(self, idx):
        """Build the Apply node that looks up `idx` in this dataset.

        :param idx: anything `theano.tensor.as_tensor_variable` accepts; it must have a dtype
            whose name starts with 'int' and be either a scalar or a vector.

        :raises TypeError: if the dtype name does not start with 'int', or if `idx` has more
            than one dimension.
        """
        _idx = theano.tensor.as_tensor_variable(idx)
        if not _idx.dtype.startswith('int'):
            raise TypeError('dataset index must have an integer dtype', _idx.dtype)
        if _idx.ndim == 0:  # one example at a time
            otype = self.single_type
        elif _idx.ndim == 1:  # many examples at a time
            otype = self.batch_type
        else:
            raise TypeError('dataset index must be a scalar or a vector', idx)
        return theano.Apply(self, [_idx], [otype()])

    def __eq__(self, other):
        # Two dataset ops are equal when they are of the same class and produce the same
        # output types.
        return type(self) == type(other) \
                and self.single_type == other.single_type \
                and self.batch_type == other.batch_type

    def __hash__(self):
        # Combines exactly the fields compared by __eq__.
        return hash(type(self)) ^ hash(self.single_type) ^ hash(self.batch_type)

    def __str__(self):
        return "%s{%s,%s}" % (self.__class__.__name__, self.single_type, self.batch_type)

    def grad(self, inputs, g_outputs):
        """Return one `None` per input: no gradient is defined for the index."""
        return [None for i in inputs]
44 | |
45 | |
class TensorDataset(Dataset):
    """Base class for Datasets in which every element shares a single TensorType.

    :param dtype: dtype given to both the single-element type and the batch type
    :param single_broadcastable: broadcastable pattern of one element (any iterable)
    :param single_shape: accepted for interface compatibility; not used by this constructor
    :param batch_size: accepted for interface compatibility; not used by this constructor
    """
    def __init__(self, dtype, single_broadcastable, single_shape=None, batch_size=None):
        element_type = theano.tensor.Tensor(
                dtype=dtype,
                broadcastable=tuple(single_broadcastable))
        # A batch type is the element pattern with one extra leading `False` entry.
        batch_pattern = (False,) + element_type.broadcastable
        minibatch_type = theano.tensor.Tensor(
                dtype=dtype,
                broadcastable=batch_pattern)
        super(TensorDataset, self).__init__(element_type, minibatch_type)
60 | |
class TensorFnDataset(TensorDataset):
    """A good base class for TensorDatasets that can be read from disk and cached in memory

    The dataset is accessed via a function call to make this Op pickle-able.  If the function
    is a normal module-level function, then this Op will be picklable.  If the dataset were a
    property of this Op, then pickling the Op would require pickling the entire dataset.
    """
    def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None):
        """
        :type fn: callable or (callable, args) tuple [MUST BE PICKLABLE!]

        :param fn: function that returns the dataset as a tensor.  Leading index is the example
            index, others are considered part of each example.  When given as a
            (callable, args) pair, `args` is unpacked into the call and must be hashable,
            because it takes part in `__hash__`.
        """
        super(TensorFnDataset, self).__init__(dtype, bcast, single_shape, batch_size)
        try:
            self.fn, self.fn_args = fn
        except (TypeError, ValueError):
            # `fn` is not a two-element pair (e.g. it is a bare callable): call it with no
            # arguments.
            self.fn, self.fn_args = fn, ()

    def __eq__(self, other):
        return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn \
                and self.fn_args == other.fn_args

    def __hash__(self):
        return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) ^ hash(self.fn_args)

    def __str__(self):
        try:
            return "%s{%s,%s}" % (self.__class__.__name__, self.fn.__name__, self.fn_args)
        except AttributeError:
            # `fn` has no `__name__` (e.g. a callable instance): show the object itself.
            return "%s{%s,%s}" % (self.__class__.__name__, self.fn, self.fn_args)

    def perform(self, node, inputs, output_storage):
        """Fetch the dataset via `fn` and store the requested element(s) in the output.

        :param node: the Apply node being evaluated (unused here)
        :param inputs: one-element sequence holding the index value
        :param output_storage: one-element sequence holding the output cell; the result is
            written to its slot 0
        """
        (idx,) = inputs
        (z,) = output_storage
        # `fn` is invoked on every call; any caching is left to `fn` itself.
        x = self.fn(*self.fn_args)
        if idx.ndim == 0:
            # Convert the 0-d index to a plain int so indexing returns a single example.
            z[0] = x[int(idx)]
        else:
            z[0] = x[idx]