annotate pylearn/dataset_ops/shapeset1.py @ 1531:88f361283a19 tip

Fix url/name to pylearn2.
author Frederic Bastien <nouiz@nouiz.org>
date Mon, 09 Sep 2013 10:08:05 -0400
parents 1c62fa857cab
children
rev   line source
874
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
1 """A Theano Op to load/access Shapeset1
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
2 """
878
faa9f880d0d2 fixes to dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 874
diff changeset
3 import theano, numpy
faa9f880d0d2 fixes to dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 874
diff changeset
4
faa9f880d0d2 fixes to dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 874
diff changeset
5 from .protocol import TensorFnDataset
faa9f880d0d2 fixes to dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 874
diff changeset
6
874
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
7 from ..datasets.shapeset1 import head_train, head_valid, head_test
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
8
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
9 #make global functions so Op can be pickled.
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
_train_cache = {}


def train_img(dtype):
    """Return the Shapeset1 training images as a numpy array of ``dtype``.

    Results are cached in ``_train_cache`` keyed by the dtype *name*, so
    repeated calls (and calls passing an equivalent dtype object) reuse the
    same array.  Loading the images also caches the labels, as int32, under
    the ``'lbl'`` key.

    :param dtype: numpy dtype (name string or dtype-like object) of the result.
        For signed/unsigned integer dtypes the raw values are multiplied by
        255 before the cast -- presumably because head_train() yields
        intensities in [0, 1] (TODO confirm).
    :returns: the cached image array.
    """
    key = numpy.dtype(dtype).name
    if key not in _train_cache:
        x, y = head_train()
        if numpy.dtype(dtype).kind in 'iu':
            # Deliberately not ``x *= 255``: that would mutate the array
            # head_train() returned, corrupting it if it is shared elsewhere.
            x = x * 255
        _train_cache[key] = numpy.asarray(x, dtype=key)
        _train_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _train_cache[key]


def train_lbl():
    """Return the Shapeset1 training labels as an int32 numpy array (cached).

    The images read alongside the labels are cached too, so a later
    train_img() call for the same dtype does not reload the data.
    """
    if 'lbl' not in _train_cache:
        x, y = head_train()
        # cache x in some format now that it's read (it isn't that big).
        # Key by dtype *name* so train_img() string lookups find it (a numpy
        # dtype object does not hash like its name).  Integer images are not
        # cached because train_img() rescales those by 255.
        if x.dtype.kind not in 'iu':
            _train_cache[x.dtype.name] = x
        _train_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _train_cache['lbl']
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
_valid_cache = {}


def valid_img(dtype):
    """Return the Shapeset1 validation images as a numpy array of ``dtype``.

    Results are cached in ``_valid_cache`` keyed by the dtype *name*, so
    repeated calls (and calls passing an equivalent dtype object) reuse the
    same array.  Loading the images also caches the labels, as int32, under
    the ``'lbl'`` key.

    :param dtype: numpy dtype (name string or dtype-like object) of the result.
        For signed/unsigned integer dtypes the raw values are multiplied by
        255 before the cast -- presumably because head_valid() yields
        intensities in [0, 1] (TODO confirm).
    :returns: the cached image array.
    """
    key = numpy.dtype(dtype).name
    if key not in _valid_cache:
        x, y = head_valid()
        if numpy.dtype(dtype).kind in 'iu':
            # Deliberately not ``x *= 255``: that would mutate the array
            # head_valid() returned, corrupting it if it is shared elsewhere.
            x = x * 255
        _valid_cache[key] = numpy.asarray(x, dtype=key)
        _valid_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _valid_cache[key]


def valid_lbl():
    """Return the Shapeset1 validation labels as an int32 numpy array (cached).

    The images read alongside the labels are cached too, so a later
    valid_img() call for the same dtype does not reload the data.
    """
    if 'lbl' not in _valid_cache:
        x, y = head_valid()
        # cache x in some format now that it's read (it isn't that big).
        # Key by dtype *name* so valid_img() string lookups find it (a numpy
        # dtype object does not hash like its name).  Integer images are not
        # cached because valid_img() rescales those by 255.
        if x.dtype.kind not in 'iu':
            _valid_cache[x.dtype.name] = x
        _valid_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _valid_cache['lbl']
879
0f33afbf517e fixed typo in dataset_ops/shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 878
diff changeset
_test_cache = {}


def test_img(dtype):
    """Return the Shapeset1 test images as a numpy array of ``dtype``.

    Results are cached in ``_test_cache`` keyed by the dtype *name*, so
    repeated calls (and calls passing an equivalent dtype object) reuse the
    same array.  Loading the images also caches the labels, as int32, under
    the ``'lbl'`` key.

    :param dtype: numpy dtype (name string or dtype-like object) of the result.
        For signed/unsigned integer dtypes the raw values are multiplied by
        255 before the cast -- presumably because head_test() yields
        intensities in [0, 1] (TODO confirm).
    :returns: the cached image array.
    """
    key = numpy.dtype(dtype).name
    if key not in _test_cache:
        x, y = head_test()
        if numpy.dtype(dtype).kind in 'iu':
            # Deliberately not ``x *= 255``: that would mutate the array
            # head_test() returned, corrupting it if it is shared elsewhere.
            x = x * 255
        _test_cache[key] = numpy.asarray(x, dtype=key)
        _test_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _test_cache[key]


def test_lbl():
    """Return the Shapeset1 test labels as an int32 numpy array (cached).

    The images read alongside the labels are cached too, so a later
    test_img() call for the same dtype does not reload the data.
    """
    if 'lbl' not in _test_cache:
        x, y = head_test()
        # cache x in some format now that it's read (it isn't that big).
        # Key by dtype *name* so test_img() string lookups find it (a numpy
        # dtype object does not hash like its name).  Integer images are not
        # cached because test_img() rescales those by 255.
        if x.dtype.kind not in 'iu':
            _test_cache[x.dtype.name] = x
        _test_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _test_cache['lbl']
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
58
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
# Split name -> (image loader, label loader).  shapeset1() looks the pair up
# here and hands these module-level functions to TensorFnDataset.
_split_fns = {
    'train': (train_img, train_lbl),
    'valid': (valid_img, valid_lbl),
    'test': (test_img, test_lbl),
}
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
63
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
def shapeset1(s_idx, split, dtype='float64', rasterized=False):
    """Build symbolic ``(images, labels)`` variables indexing into Shapeset1.

    :param s_idx: symbolic integer index into the split.  The ndim handling
        below accepts a result of ndim 1 (presumably a scalar index, one
        example) or ndim 2 (presumably a vector of indices, a minibatch).
    :param split: one of ``'train'``, ``'valid'`` or ``'test'``.
    :param dtype: dtype of the image variable.  Integer dtypes get pixel
        values scaled by 255 by the image loaders.
    :param rasterized: return examples as 1024-element vectors (True) or
        32x32 matrices (False).
    :returns: ``(x, y)`` -- the image variable and an int32 label variable.
    :raises KeyError: if ``split`` is not a known split name.
    :raises ValueError: if the image variable has an unexpected ndim.
    """
    try:
        x_fn, y_fn = _split_fns[split]
    except KeyError:
        raise KeyError('unknown shapeset1 split %r; expected one of %s'
                       % (split, sorted(_split_fns)))

    # Each example is stored as a flat 1024 (= 32 * 32) vector.
    x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)),
                        single_shape=(1024,))(s_idx)
    y = TensorFnDataset(dtype='int32', bcast=(), fn=y_fn)(s_idx)
    if x.ndim == 1:
        if not rasterized:
            x = x.reshape((32, 32))
    elif x.ndim == 2:
        if not rasterized:
            x = x.reshape((x.shape[0], 32, 32))
    else:
        # Explicit raise: ``assert False`` is stripped under ``python -O``,
        # which would silently return a wrongly shaped variable.
        raise ValueError('unexpected ndim %d for shapeset1 images' % x.ndim)
    return x, y


# Number of label classes in Shapeset1 -- presumably the count of distinct
# values in y; confirm against the dataset.
nclasses = 10
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
91
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
def glviewer(split='train'):
    """Open a GlViewer over the uint8, 32x32 images of one Shapeset1 split.

    :param split: ``'train'``, ``'valid'`` or ``'test'``.
    """
    from glviewer import GlViewer
    idx = theano.tensor.iscalar()
    images, _labels = shapeset1(idx, split, dtype='uint8', rasterized=False)
    # Compiled function: integer index -> image; GlViewer presumably calls it
    # per index it displays.
    show_image = theano.function([idx], images)
    GlViewer(show_image).main()
76f71e10f5ef added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
97