Mercurial > pylearn
annotate pylearn/dataset_ops/shapeset1.py @ 1531:88f361283a19 tip
Fix url/name to pylearn2.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Mon, 09 Sep 2013 10:08:05 -0400 |
parents | 1c62fa857cab |
children |
rev | line source |
---|---|
874
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
1 """A Theano Op to load/access Shapeset1 |
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
2 """ |
878
faa9f880d0d2
fixes to dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
874
diff
changeset
|
3 import theano, numpy |
faa9f880d0d2
fixes to dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
874
diff
changeset
|
4 |
faa9f880d0d2
fixes to dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
874
diff
changeset
|
5 from .protocol import TensorFnDataset |
faa9f880d0d2
fixes to dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
874
diff
changeset
|
6 |
874
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
7 from ..datasets.shapeset1 import head_train, head_valid, head_test |
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
8 |
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
# The per-split loaders are module-level functions (not lambdas/closures) so
# that the TensorFnDataset Ops that reference them can be pickled.
#
# Each split has its own cache dict: image arrays are stored under their dtype
# name and the int32 label array under the key 'lbl'.
_train_cache = {}
_valid_cache = {}
_test_cache = {}


def _convert_img(x, dtype):
    """Return the raw image array `x` converted to `dtype`.

    For integer dtypes the pixels are multiplied by 255 before the cast
    (this presumes the stored pixels are floats in [0, 1] -- TODO confirm
    against ``datasets.shapeset1``).  The product is a new array, so `x`
    itself is never modified in place; that matters because `_load_lbl`
    caches the raw array, and the reader may hand back that same object.

    :param x: raw images as returned by a ``head_*`` reader.
    :param dtype: numpy dtype (or dtype name) of the result.
    :returns: the converted array (possibly `x` itself, uncopied, when it
        already has the requested non-integer dtype).
    """
    # Test the dtype *kind* rather than its spelling, so that aliases such as
    # 'u1' or 'i4' are scaled exactly like 'uint8' / 'int32'.
    if numpy.dtype(dtype).kind in ('i', 'u'):
        x = x * 255
    return numpy.asarray(x, dtype=dtype)


def _load_img(cache, head_fn, dtype):
    """Return one split's images as `dtype`, reading the split at most once per dtype.

    :param cache: the split's cache dict (e.g. ``_train_cache``).
    :param head_fn: zero-argument reader returning ``(images, labels)``.
    :param dtype: dtype name of the requested images; also the cache key.
    """
    if dtype not in cache:
        x, y = head_fn()
        cache[dtype] = _convert_img(x, dtype)
        # The labels came with the same read; keep them as well.
        if 'lbl' not in cache:
            cache['lbl'] = numpy.asarray(y, dtype='int32')
    return cache[dtype]


def _load_lbl(cache, head_fn):
    """Return one split's labels as an int32 array, reading the split at most once.

    :param cache: the split's cache dict (e.g. ``_train_cache``).
    :param head_fn: zero-argument reader returning ``(images, labels)``.
    """
    if 'lbl' not in cache:
        x, y = head_fn()
        # The images were read anyway (they aren't that big), so cache them
        # in their stored dtype.  The key must be the dtype *name*: a numpy
        # dtype object compares equal to its name but hashes differently, so
        # a dtype-object key is never found by a lookup such as 'float64'.
        # The cached value is exactly what _load_img would build for that key.
        key = str(x.dtype)
        if key not in cache:
            cache[key] = _convert_img(x, key)
        cache['lbl'] = numpy.asarray(y, dtype='int32')
    return cache['lbl']


def train_img(dtype):
    """Return the Shapeset1 training images as `dtype` (cached)."""
    return _load_img(_train_cache, head_train, dtype)


def train_lbl():
    """Return the Shapeset1 training labels as int32 (cached)."""
    return _load_lbl(_train_cache, head_train)


def valid_img(dtype):
    """Return the Shapeset1 validation images as `dtype` (cached)."""
    return _load_img(_valid_cache, head_valid, dtype)


def valid_lbl():
    """Return the Shapeset1 validation labels as int32 (cached)."""
    return _load_lbl(_valid_cache, head_valid)


def test_img(dtype):
    """Return the Shapeset1 test images as `dtype` (cached)."""
    return _load_img(_test_cache, head_test, dtype)


def test_lbl():
    """Return the Shapeset1 test labels as int32 (cached)."""
    return _load_lbl(_test_cache, head_test)


# split name -> (image loader, label loader); looked up by shapeset1().
_split_fns = dict(
    train=(train_img, train_lbl),
    valid=(valid_img, valid_lbl),
    test=(test_img, test_lbl))
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
63 |
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
def shapeset1(s_idx, split, dtype='float64', rasterized=False):
    """Return symbolic ``(x, y)`` variables for the Shapeset1 example(s) at `s_idx`.

    :param s_idx: symbolic integer index into the split.  A scalar index (as
        used by ``glviewer``) gives a single 1-d image variable; the 2-d case
        handled below presumably comes from a vector of indices selecting a
        batch -- TODO confirm against ``TensorFnDataset``.

    :param split: which split to read: 'train', 'valid' or 'test'.  Any other
        value raises KeyError.

    :param dtype: dtype name of the image variable.  Integer dtypes such as
        'uint8' are produced by scaling the stored pixels by 255.

    :param rasterized: return examples as vectors of 1024 pixels (True) or as
        32x32 matrices (False).

    :returns: ``(x, y)`` -- the image variable and the int32 label variable.

    :raises ValueError: if the image variable is neither 1-d nor 2-d.
    """
    x_fn, y_fn = _split_fns[split]

    side = 32  # each image is stored as side * side = 1024 pixels

    x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)),
                        single_shape=(side * side,))(s_idx)
    y = TensorFnDataset(dtype='int32', bcast=(), fn=y_fn)(s_idx)

    if x.ndim == 1:
        if not rasterized:
            x = x.reshape((side, side))
    elif x.ndim == 2:
        if not rasterized:
            x = x.reshape((x.shape[0], side, side))
    else:
        # Raise explicitly (not `assert`) so the check survives `python -O`.
        raise ValueError(
            'shapeset1: expected a 1-d or 2-d image variable, got ndim=%d'
            % x.ndim)
    return x, y
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
# Number of label classes in Shapeset1 -- presumably the labels returned by
# the *_lbl loaders take values in range(nclasses); TODO confirm against
# datasets.shapeset1.
nclasses = 10
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
91 |
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
def glviewer(split='train'):
    """Browse one Shapeset1 split interactively with GlViewer.

    Compiles a Theano function mapping an integer index to the corresponding
    uint8 image (as a 32x32 matrix) and hands it to ``GlViewer``.

    :param split: 'train', 'valid' or 'test'.
    """
    from glviewer import GlViewer

    index = theano.tensor.iscalar()
    image = shapeset1(index, split, dtype='uint8', rasterized=False)[0]
    fetch_image = theano.function([index], image)
    GlViewer(fetch_image).main()
76f71e10f5ef
added dataset_ops.shapeset1
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
97 |