annotate pylearn/dataset_ops/cifar10.py @ 1476:8c10bda4bb5f

Configured default train/valid/test split for icml07.MNIST_rotated_background dataset. Defaults are the ones used by Hugo in the ICML07 paper and in all contracting auto-encoder papers.
author gdesjardins
date Fri, 20 May 2011 16:53:00 -0400
parents d9dd09a2ee90
children
rev   line source
838
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
1 """
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
2 CIFAR-10 dataset of labeled small colour images.
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
3
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
4 For details see either:
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
5
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
6 - http://www.cs.toronto.edu/~kriz/cifar.html, or
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
7
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
8 - /data/lisa/data/cifar10/cifar-10-batches-py/readme.html
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
9
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
10 """
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
11 import cPickle, os, sys, numpy
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
12 from pylearn.datasets.config import data_root
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
13 import theano
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
14
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
15 from protocol import TensorFnDataset # protocol.py __init__.py
865
49c1035fe582 added code comment for vim shortcut
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 850
diff changeset
16 from .memo import memo # memo.py
838
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
17
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
def _unpickle(filename, dtype):
    """Load one pickled CIFAR-10 batch file and return ``(data, labels)``.

    Implements loading as well as dtype-conversion and dtype-scaling.

    :param filename: path of a pickled batch; the unpickled object must
        provide the keys ``'data'`` and ``'labels'``.

    :param dtype: dtype the ``'data'`` entry is converted to.  For
        ``'float32'`` and ``'float64'`` the values are additionally divided
        by 255; any other dtype is returned unscaled.

    :returns: pair ``(data, labels)`` of numpy arrays; ``labels`` is always
        int32.

    NOTE(review): this unpickles whatever is stored at ``filename``; only use
    it on trusted dataset files.
    """
    fo = open(filename, 'rb')
    try:
        # try/finally so the handle is released even when unpickling fails
        batch = cPickle.load(fo)
    finally:
        fo.close()
    data = numpy.asarray(batch['data'], dtype=dtype)
    labels = numpy.asarray(batch['labels'], dtype='int32')
    if str(dtype) in ('float32', 'float64'):
        # NOTE(review): presumably the stored values are 0..255 bytes, so
        # this maps them to [0, 1] -- confirm against the dataset files.
        data /= 255
    return data, labels
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
27
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
@memo
def all_data_labels(dtype='uint8'):
    """Return ``(data, labels)`` for every example: training batches 1..5
    followed by the test batch.

    :param dtype: dtype forwarded to the batch loaders.
    """
    data_parts = []
    label_parts = []
    for i in range(1, 6):
        batch_path = os.path.join(data_root(), 'cifar10',
                'cifar-10-batches-py', 'data_batch_%i' % i)
        batch_x, batch_y = _unpickle(batch_path, dtype)
        data_parts.append(batch_x)
        label_parts.append(batch_y)

    # the test examples go last, after all the training batches
    test_x, test_y = test_data_labels(dtype)
    data_parts.append(test_x)
    label_parts.append(test_y)

    return numpy.vstack(data_parts), numpy.hstack(label_parts)
a165f2666643 cifar10 - added support for "all" split
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1283
diff changeset
36
a165f2666643 cifar10 - added support for "all" split
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1283
diff changeset
@memo
def train_data_labels(dtype='uint8'):
    """Return ``(data, labels)`` built from training batches 1..5, in order.

    :param dtype: dtype forwarded to ``_unpickle`` for the image data.
    """
    data_parts = []
    label_parts = []
    for i in range(1, 6):
        batch_path = os.path.join(data_root(), 'cifar10',
                'cifar-10-batches-py', 'data_batch_%i' % i)
        batch_x, batch_y = _unpickle(batch_path, dtype)
        data_parts.append(batch_x)
        label_parts.append(batch_y)
    # rows of data and entries of labels stay aligned batch by batch
    return numpy.vstack(data_parts), numpy.hstack(label_parts)
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
@memo
def test_data_labels(dtype='uint8'):
    """Return ``(data, labels)`` loaded from the ``test_batch`` file.

    :param dtype: dtype forwarded to ``_unpickle`` for the image data.
    """
    test_path = os.path.join(data_root(), 'cifar10', 'cifar-10-batches-py',
            'test_batch')
    return _unpickle(test_path, dtype)
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
48
845
825358a8072f added glviewer to cifar10
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 838
diff changeset
def forget():
    """Call ``.forget()`` on each memoized ``*_data_labels`` loader."""
    for loader in (train_data_labels, test_data_labels, all_data_labels):
        loader.forget()
845
825358a8072f added glviewer to cifar10
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 838
diff changeset
53
838
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
54
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
55 # functions for TensorFnDataset
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
56
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
def train_data(dtype):
    """Images of the 'train' split: the first 40000 training examples."""
    loaded = train_data_labels(dtype)
    images = loaded[0]
    return images[:40000]
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
def train_labels():
    """Labels of the 'train' split: the first 40000 training examples."""
    loaded = train_data_labels()
    targets = loaded[1]
    return targets[:40000]
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
def valid_data(dtype):
    """Images of the 'valid' split: training examples from index 40000 on."""
    loaded = train_data_labels(dtype)
    images = loaded[0]
    return images[40000:]
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
def valid_labels():
    """Labels of the 'valid' split: training examples from index 40000 on."""
    loaded = train_data_labels()
    targets = loaded[1]
    return targets[40000:]
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
def test_data(dtype):
    """Images of the 'test' split, as loaded by ``test_data_labels``."""
    loaded = test_data_labels(dtype)
    return loaded[0]
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
def test_labels():
    """Labels of the 'test' split, as loaded by ``test_data_labels``."""
    loaded = test_data_labels()
    return loaded[1]
1288
a165f2666643 cifar10 - added support for "all" split
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1283
diff changeset
def all_data(dtype):
    """Images of the 'all' split (training batches followed by test batch).

    :param dtype: must be ``'uint8'``; this split is only served in the
        loaders' default dtype.

    :raises ValueError: if ``dtype`` is anything other than ``'uint8'``.
    """
    if dtype != 'uint8':
        # Say what went wrong: callers reaching this through cifar10() with
        # its default dtype would otherwise get an empty ValueError.
        raise ValueError("the 'all' split only supports dtype 'uint8'", dtype)
    return all_data_labels()[0]
a165f2666643 cifar10 - added support for "all" split
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1283
diff changeset
def all_labels():
    """Labels of the 'all' split, as loaded by ``all_data_labels``."""
    loaded = all_data_labels()
    return loaded[1]
838
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
75
1453
d9dd09a2ee90 added loop argument to cifar10 op constructor
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1400
diff changeset
# Number of examples in each named split; used to wrap indexes when looping.
split_sizes = {
    'train': 40000,
    'valid': 10000,
    'test': 10000,
    'all': 60000,
}
838
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
81
1282
f36f59e53c28 cifar10 op - made splits constructors a parameter
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 951
diff changeset
def cifar10(s_idx, split, dtype='float64', rasterized=False, color='grey',
        split_options = {'train':(train_data, train_labels),
            'valid': (valid_data, valid_labels),
            'test': (test_data, test_labels),
            'all': (all_data, all_labels),
            },
        loop=False
        ):
    """
    Returns a pair (img, label) of theano expressions for cifar-10 samples

    :param s_idx: the indexes

    :param split: name of the split to read; must be a key of `split_options`

    :param dtype: dtype of the returned image expression

    :param rasterized: return examples as vectors (True) or 32x32 matrices (False)

    :param color: control how to deal with the color in the images'
        - grey  greyscale (with luminance weighting)
        - rgb   add a trailing dimension of length 3 with rgb colour channels

    :param split_options: dictionary mapping a split name to a pair
        (data_fn, labels_fn).  data_fn is called with the dtype, labels_fn
        with no argument.  The default dictionary is only read, never
        modified, by this function.

    :param loop: if True, `s_idx` is taken modulo the size of the split (as
        listed in `split_sizes`) before the lookup.

    :raises ValueError: for an unknown `split` or `color`, or when the image
        expression has neither 1 nor 2 dimensions.

    """

    if split not in split_options:
        raise ValueError('invalid split option', (split, split_options.keys()))

    color_options = ('grey', 'rgb')
    if color not in color_options:
        raise ValueError('invalid color option', (color, color_options))

    x_fn, y_fn = split_options[split]

    x_op = TensorFnDataset(dtype, (False,), (x_fn, (dtype,)), (3072,))
    y_op = TensorFnDataset('int32', (), y_fn)

    if loop:
        s_idx = s_idx % split_sizes[split]

    x = x_op(s_idx)
    y = y_op(s_idx)

    if color=='grey':
        # Y = 0.3R + 0.59G + 0.11B from
        # http://gimp-savvy.com/BOOK/index.html?node54.html
        rgb_dtype = 'float32'
        if dtype == 'float64':
            rgb_dtype = dtype
        r = numpy.asarray(.3, dtype=rgb_dtype)
        g = numpy.asarray(.59, dtype=rgb_dtype)
        b = numpy.asarray(.11, dtype=rgb_dtype)

    # Each example is 3072 values laid out as three consecutive planes of
    # 1024 (= 32*32) values, taken here as red, green, blue in that order.
    if x.ndim == 1:
        # a single example
        if color=='grey':
            x = r * x[:1024] + g * x[1024:2048] + b * x[2048:]
            if dtype=='uint8':
                # the weighted sum is floating point; go back to bytes
                x = theano.tensor.cast(x, 'uint8')
            if not rasterized:
                x = x.reshape((32,32))
        elif color=='rgb':
            # the strides aren't what you'd expect between channels,
            # but theano is all about weird strides
            if rasterized:
                x = x.reshape((3,32*32)).T
            else:
                x = x.reshape((3,32,32)).dimshuffle(1, 2, 0)
        else:
            # unreachable after the color check above; kept as a guard
            raise NotImplementedError('color', color)
    elif x.ndim == 2:
        # a minibatch of examples, one per row
        N = x.shape[0] # symbolic
        if color=='grey':
            x = r * x[:,:1024] + g * x[:,1024:2048] + b * x[:,2048:]
            if dtype=='uint8':
                x = theano.tensor.cast(x, 'uint8')
            if not rasterized:
                # reshape returns a new expression: it must be re-assigned,
                # otherwise the batch stays rasterized as (N, 1024)
                x = x.reshape((N, 32, 32))
        elif color=='rgb':
            # note: the strides aren't what you'd expect between channels,
            # but a copy of the data would correct that.
            if rasterized:
                x = x.reshape((N, 3,32*32)).dimshuffle(0, 2, 1)
            else:
                x = x.reshape((N,3,32,32)).dimshuffle(0, 2, 3, 1)
        else:
            # unreachable after the color check above; kept as a guard
            raise NotImplementedError('color', color)
    else:
        raise ValueError('x has too many dimensions', x.ndim)

    return x, y
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
188
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
nclasses = 10  # number of label classes in CIFAR-10
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
190
1283
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
191 import pylearn.datasets.image_patches
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
192 import pylearn.preprocessing.pca
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
193
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
@memo
def random_cifar_patches(dtype, N,R,C, centered=True):
    """Return N random R x C colour patches cut from the first 40000
    training images.

    The mean over each patch's own values is subtracted from that patch;
    the result is then rearranged with the channel axis right after the
    example axis.

    :param dtype: dtype the training images are loaded with.

    :param N: number of patches to extract.

    :param R: patch height passed to the patch extractor.

    :param C: patch width passed to the patch extractor.

    :param centered: if True, additionally subtract the mean patch (the
        mean over the example axis) from every patch.
    """
    # These used to be arguments, but optional arguments don't work well
    # with the cache because the cache doesn't [yet] look up what they are
    rng_seed = 89234
    channel_rank = 2

    rng = numpy.random.RandomState(rng_seed)

    # rows of 3072 values -> (example, channel, row, col) -> channels last
    train_images = train_data_labels(dtype)[0][:40000]
    imgs = train_images.reshape((40000, 3, 32, 32)).transpose((0, 2, 3, 1))
    forget() #un-cache the original images

    patches = pylearn.datasets.image_patches.extract_random_patches(
            imgs, N, R, C, rng)
    orig_shape = patches.shape

    # center individual examples (subtract off mean colour)
    flat = patches.reshape((orig_shape[0], R*C*3))
    flat -= flat.mean(axis=1).reshape((N, 1))
    patches = flat.reshape(orig_shape)

    if channel_rank == 2:
        # put the channels the cifar10 way :/
        patches = patches.transpose((0,3,1,2)).copy()
    elif channel_rank != 4:
        # channel_rank 4 means: leave the channels as the trailing axis
        raise NotImplementedError()

    if centered:
        patches -= patches.mean(axis=0)
    return patches
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
223
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
@memo
def random_cifar_patches_pca(max_components, max_energy_fraction, dtype, N,R,C,*args):
    """
    Fit a PCA on N random R-by-C colour patches and return it.

    The patches are obtained from random_cifar_patches(dtype, N, R, C, *args)
    and flattened to one row of R*C*3 values per patch.  They are passed to
    pca_from_examples with x_centered=True, i.e. they are declared to be
    centered already.

    Only the first element of the pair returned by pca_from_examples is
    handed back (an (eigvals, eigvecs) pair, per the original note on this
    function -- confirm against pylearn.preprocessing.pca).
    """
    flat_patches = random_cifar_patches(dtype, N, R, C, *args).reshape(
        (N, R * C * 3))
    principal, _ignored = pylearn.preprocessing.pca.pca_from_examples(
        flat_patches,
        max_components,
        max_energy_fraction,
        x_centered=True)
    return principal
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
233
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
@memo
def whitened_random_cifar_patches(max_components, max_energy_fraction, dtype,N,R,C,*args):
    """
    Return N random R-by-C colour patches after PCA whitening, cast to dtype.

    The PCA comes from random_cifar_patches_pca, called with the same
    arguments.  The patches are fetched again from random_cifar_patches and
    flattened to rows of R*C*3 values before being passed to pca_whiten.
    """
    principal = random_cifar_patches_pca(
        max_components, max_energy_fraction, dtype, N, R, C, *args)

    raw = random_cifar_patches(dtype, N, R, C, *args)
    flat_patches = raw.reshape((N, R * C * 3))

    # Drop the memoized raw patches once we hold our own reference to them.
    random_cifar_patches.forget()

    whitened = pylearn.preprocessing.pca.pca_whiten(principal, flat_patches)
    return whitened.astype(dtype)
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
240
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
def cifar10_patches(s_idx, split, dtype='float32', rasterized=True, color='rgb',
        n_patches=1000, patch_size=(8,8), pca_components=80):
    """
    Return a symbolic expression for PCA-whitened random CIFAR-10 patches.

    A TensorFnDataset op is built around whitened_random_cifar_patches, which
    will be called with the arguments
    (pca_components, None, dtype, n_patches, patch_size[0], patch_size[1]).
    The op is applied to `s_idx % n_patches`, so every index wraps around
    into the pool of `n_patches` patches, and the result is flattened with
    `flatten(2)`.

    Arguments:
      s_idx          -- index variable; its `ndim` must be 1.
      split          -- only 'train' is supported.
      dtype          -- only 'float32' is supported.
      rasterized     -- only True is supported.
      color          -- only 'rgb' is supported.
      n_patches      -- size of the patch pool that indices wrap into.
      patch_size     -- (rows, cols) of each patch.
      pca_components -- forwarded as max_components of the PCA.

    Raises NotImplementedError for any unsupported argument value.  All
    argument checks run before the dataset op is constructed.
    """
    if split != 'train':
        raise NotImplementedError(
            "cifar10_patches: only split='train' is supported, got %r"
            % (split,))
    if dtype != 'float32':
        raise NotImplementedError(
            "cifar10_patches: only dtype='float32' is supported, got %r"
            % (dtype,))
    if color != 'rgb':
        raise NotImplementedError(
            "cifar10_patches: only color='rgb' is supported, got %r"
            % (color,))
    if s_idx.ndim != 1:
        raise NotImplementedError(
            "cifar10_patches: s_idx must have ndim == 1, got ndim %r"
            % (s_idx.ndim,))
    if not rasterized:
        # Fail fast: the original only rejected this after building the op.
        raise NotImplementedError(
            "cifar10_patches: only rasterized=True is supported")

    x_op = TensorFnDataset(dtype, (False,),
            (whitened_random_cifar_patches, (
                pca_components, None, dtype, n_patches,
                patch_size[0], patch_size[1])),
            (patch_size[0], patch_size[1], 3))

    # Wrap indices so that any integer selects one of the n_patches patches.
    x = x_op(s_idx % n_patches)

    return x.flatten(2)
a73db8d65abb cifar10 op - added an op for generating whitened patches
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 1282
diff changeset
263
845
825358a8072f added glviewer to cifar10
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 838
diff changeset
def glviewer(split='train'):
    """
    Open a GlViewer on the images of `split`.

    A theano function of one integer scalar index is compiled from the first
    element of cifar10(index, split, dtype='uint8', rasterized=False,
    color='rgb') and passed to GlViewer, whose main() is then run.
    """
    from glviewer import GlViewer

    index = theano.tensor.iscalar()
    outputs = cifar10(index, split, dtype='uint8', rasterized=False, color='rgb')
    fetch_image = theano.function([index], outputs[0])

    viewer = GlViewer(fetch_image)
    viewer.main()
825358a8072f added glviewer to cifar10
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 838
diff changeset
269
838
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
270
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
if 0:
    # Disabled legacy iterators: nothing below is defined at import time.

    def datarow_to_greyscale_28by28(row, max_scale=1.0):
        """
        Turn one 3072-value row into a flat 784-value greyscale crop.

        The row is viewed as 3 channels of 1024 values, averaged over the
        channels, scaled by max_scale / 255.0, reshaped to 32x32, and the
        28x28 window at rows/columns 1..28 is returned flattened.
        """
        assert row.shape == (3072,)
        channels = row.reshape((3, 1024))
        grey = numpy.mean(channels, axis=0) * max_scale / 255.0
        assert grey.shape == (1024,)
        window = grey.reshape((32, 32))[1:29, 1:29]
        return window.reshape((784,))


    def batch_iter(b_idx, max_scale=1.0):
        """
        Yield (greyscale 28x28 row, label) pairs for one batch of 10000.

        b_idx is either the string 'test' or a key into data_batches.
        """
        loader = data_batch_test if b_idx == 'test' else data_batches[b_idx]
        data, labels = loader()
        assert len(data) == 10000
        assert len(labels) == 10000

        for row, label in zip(data, labels):
            yield datarow_to_greyscale_28by28(row, max_scale=max_scale), label

    def train_iter(scale=1.0):
        """Cycle forever over batches 0 through 3."""
        while True:
            for b_idx in xrange(4):
                for example in batch_iter(b_idx, max_scale=scale):
                    yield example

    def valid_iter(scale=1.0):
        """Make a single pass over batch 4."""
        for example in batch_iter(4, max_scale=scale):
            yield example

    def test_iter(scale=1.0):
        """Make a single pass over the 'test' batch."""
        for example in batch_iter('test', max_scale=scale):
            yield example
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
305
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
306
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
307
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
308 # the following function is patterned after the MNIST.mnist function
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
309 class GreyScale(theano.Op):
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
310 def __eq__(self, other):
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
311 return type(self) == type(other)
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
312
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
313 def __hash__(self):
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
314 return hash(type(self))
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
315
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
316 def make_node(self, x):
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
317 x_ = theano.tensor.as_tensor_variable(x)
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
318 if x_.type.ndim not in (3,4):
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
319 raise TypeError('Greyscaling a tensor with unexpected number of dimensions',
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
320 x_.type.ndim)
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
321 z_type = theano.tensor.TensorType(
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
322 dtype=x.dtype,
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
323 broadcastable = x_.type.broadcastable[:-1])
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
324 return theano.Apply(self, [x_], [z_type()])
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
325
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
326 def perform(self, node, (x,), (z,)):
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
327 #TODO: Use PIL for real greyscale
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
328 z[0] = numpy.asarray(x.mean(axis=x.ndim-1), dtype=node.outputs[0].type.dtype)
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
329
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
330 def grad(self, (x,), (z,)):
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
331 # TODO: this op is actually differentiable...
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
332 # when perform is done with PIL, then TODO is to look up the constants of the RGB
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
333 # weights, and put them here in the grad function.
4f7e0edee7d0 adding cifar10 dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
334 return [None]