annotate pylearn/dataset_ops/protocol.py @ 841:d7ee9c906d7e

added the ability to pass fn_args to TensorFnDataset
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 22 Oct 2009 18:52:47 -0400
parents 67b92a42f86b
children b2821fce15de
rev   line source
832
67b92a42f86b added dataset_ops
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
"""Convenience base classes for writing Dataset ops.

A Dataset op maps an integer index to one dataset element, or (minibatch
mode) a vector of integer indices to a batch of dataset elements.
"""

__docformat__ = "restructuredtext_en"
67b92a42f86b added dataset_ops
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
6 import theano
67b92a42f86b added dataset_ops
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
7
67b92a42f86b added dataset_ops
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
class Dataset(theano.Op):
    """
    The basic dataset interface is an expression that maps an integer to a dataset element.

    There is also a minibatch option, in which the expression maps an array of integers to a
    list or array of dataset elements.

    This class only builds the graph node; it defines no ``perform``.  Concrete subclasses
    (e.g. TensorFnDataset) provide the actual lookup.

    :param single_type: theano Type instantiated for the output when the index is a scalar.
    :param batch_type: theano Type instantiated for the output when the index is a vector.
    """
    def __init__(self, single_type, batch_type):
        self.single_type = single_type
        self.batch_type = batch_type

    def make_node(self, idx):
        """Build an Apply node that looks up element(s) ``idx`` of the dataset.

        :param idx: anything accepted by ``theano.tensor.as_tensor_variable`` with an
            integer dtype: 0-d for a single example, 1-d for a minibatch.
        :returns: an Apply whose single output is a fresh ``single_type`` variable for a
            0-d index, or a fresh ``batch_type`` variable for a 1-d index.
        :raises TypeError: if the index dtype does not start with ``'int'`` (so unsigned
            dtypes such as ``'uint8'`` are rejected too), or if the index is not 0-d or 1-d.
        """
        _idx = theano.tensor.as_tensor_variable(idx)
        if not _idx.dtype.startswith('int'):
            raise TypeError('Dataset index must have an int dtype, got %s' % (_idx.dtype,), idx)
        if _idx.ndim == 0:  # one example at a time
            otype = self.single_type
        elif _idx.ndim == 1:  # many examples at a time
            otype = self.batch_type
        else:
            raise TypeError('Dataset index must be 0-d or 1-d, got ndim=%i' % _idx.ndim, idx)
        return theano.Apply(self, [_idx], [otype()])

    def __eq__(self, other):
        """Ops are equal when they have the same class and the same output types."""
        return type(self) == type(other) \
                and self.single_type == other.single_type \
                and self.batch_type == other.batch_type

    def __hash__(self):
        # Built from the same fields that __eq__ compares, so equal ops hash equally.
        return hash(type(self)) ^ hash(self.single_type) ^ hash(self.batch_type)

    def __str__(self):
        return "%s{%s,%s}" % (self.__class__.__name__, self.single_type, self.batch_type)

    def grad(self, inputs, g_outputs):
        """Return None (no gradient) for every input."""
        return [None] * len(inputs)
67b92a42f86b added dataset_ops
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
44
67b92a42f86b added dataset_ops
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
45
67b92a42f86b added dataset_ops
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
class TensorDataset(Dataset):
    """Base class for Datasets whose elements all share one TensorType.

    The minibatch type is the element type with one extra leading,
    non-broadcastable dimension whose length is ``batch_size``.

    :param dtype: dtype shared by single elements and minibatches.
    :param single_broadcastable: broadcastable pattern of a single element
        (any iterable of bools; it is converted to a tuple).
    :param single_shape: optional shape of a single element (default None).
    :param batch_size: optional length of the minibatch dimension (default None).
    """
    def __init__(self, dtype, single_broadcastable, single_shape=None, batch_size=None):
        elem_bcast = tuple(single_broadcastable)
        elem_type = theano.tensor.Tensor(
                broadcastable=elem_bcast,
                dtype=dtype,
                shape=single_shape)

        # The minibatch dimension is prepended and is never broadcastable.
        minibatch_bcast = (False,) + elem_type.broadcastable
        minibatch_shape = (batch_size,) + elem_type.shape
        minibatch_type = theano.tensor.Tensor(
                broadcastable=minibatch_bcast,
                dtype=dtype,
                shape=minibatch_shape)

        super(TensorDataset, self).__init__(elem_type, minibatch_type)
67b92a42f86b added dataset_ops
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
60
67b92a42f86b added dataset_ops
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
class TensorFnDataset(TensorDataset):
    """TensorDataset whose elements are read from the value returned by a function.

    On every ``perform``, ``fn(*fn_args)`` is called and its result is indexed
    with the requested index (or vector of indices).  Any caching of the
    underlying data is up to ``fn``.

    :param dtype: dtype of a dataset element.
    :param bcast: broadcastable pattern of a single element.
    :param fn: either a callable, or a pair ``(callable, fn_args)`` whose second
        item is an iterable of positional arguments for the callable.
    :param single_shape: optional shape of a single element.
    :param batch_size: optional length of the minibatch dimension.
    """
    def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None):
        super(TensorFnDataset, self).__init__(dtype, bcast, single_shape, batch_size)
        # A bare callable cannot be unpacked (TypeError); something that unpacks
        # to the wrong number of items raises ValueError.  In both cases ``fn``
        # is treated as a callable that takes no arguments.
        try:
            fn, fn_args = fn
        except (TypeError, ValueError):
            fn_args = ()
        self.fn = fn
        # Normalise to a tuple so __hash__ works even when a list was supplied.
        self.fn_args = tuple(fn_args)

    def __eq__(self, other):
        return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn \
                and self.fn_args == other.fn_args

    def __hash__(self):
        # Same fields as __eq__, so equal ops hash equally.
        return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) ^ hash(self.fn_args)

    def __str__(self):
        try:
            fn_name = self.fn.__name__
        except AttributeError:
            # e.g. functools.partial objects or callable instances have no __name__
            fn_name = self.fn
        return "%s{%s,%s}" % (self.__class__.__name__, fn_name, self.fn_args)

    def perform(self, node, inputs, output_storage):
        """Compute the dataset lookup.

        :param node: the Apply node (unused).
        :param inputs: one-element sequence holding the index value; presumably a
            numpy integer scalar or vector -- TODO confirm against theano's contract.
        :param output_storage: one-element sequence holding the output cell ``z``;
            the result is stored in ``z[0]``.
        """
        # Unpack in the body: tuple parameters in a ``def`` are Python-2-only
        # syntax (removed by PEP 3113).
        idx, = inputs
        z, = output_storage
        x = self.fn(*self.fn_args)
        if idx.ndim == 0:
            z[0] = x[int(idx)]
        else:
            z[0] = x[idx]