Mercurial > pylearn
changeset 1288:a165f2666643
cifar10 - added support for "all" split
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 29 Sep 2010 18:35:40 -0400 |
parents | 4fa2a32e8fde |
children | 092cd4cd2009 |
files | pylearn/dataset_ops/cifar10.py |
diffstat | 1 files changed, 31 insertions(+), 12 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/dataset_ops/cifar10.py Wed Sep 29 18:34:47 2010 -0400 +++ b/pylearn/dataset_ops/cifar10.py Wed Sep 29 18:35:40 2010 -0400 @@ -21,11 +21,20 @@ dict = cPickle.load(fo) fo.close() data, labels = numpy.asarray(dict['data'], dtype=dtype), numpy.asarray(dict['labels'], dtype='int32') - if dtype in ('float32', 'float64'): + if str(dtype) in ('float32', 'float64'): data /= 255 return data, labels @memo +def all_data_labels(dtype='uint8'): + train_batch_data, train_batch_labels = zip(*[ _unpickle( os.path.join(data_root(), 'cifar10', + 'cifar-10-batches-py', 'data_batch_%i'%i), dtype) for i in range(1,6)]) + test_batch_data, test_batch_labels = test_data_labels(dtype) + data = numpy.vstack(list(train_batch_data)+[test_batch_data]) + labels = numpy.hstack(list(train_batch_labels)+[test_batch_labels]) + return data, labels + +@memo def train_data_labels(dtype='uint8'): batch_data, batch_labels = zip(*[ _unpickle( os.path.join(data_root(), 'cifar10', 'cifar-10-batches-py', 'data_batch_%i'%i), dtype) for i in range(1,6)]) @@ -40,6 +49,7 @@ def forget(): train_data_labels.forget() test_data_labels.forget() + all_data_labels.forget() # functions for TensorFnDataset @@ -56,12 +66,20 @@ return test_data_labels(dtype)[0] def test_labels(): return test_data_labels()[1] +def all_data(dtype): + if dtype!='uint8': + raise ValueError() + return all_data_labels()[0] +def all_labels(): + return all_data_labels()[1] def cifar10(s_idx, split, dtype='float64', rasterized=False, color='grey', split_options = {'train':(train_data, train_labels), 'valid': (valid_data, valid_labels), - 'test': (test_data, test_labels)} + 'test': (test_data, test_labels), + 'all': (all_data, all_labels), + } ): """ Returns a pair (img, label) of theano expressions for cifar-10 samples @@ -95,14 +113,15 @@ x = x_op(s_idx) y = y_op(s_idx) - # Y = 0.3R + 0.59G + 0.11B from - # http://gimp-savvy.com/BOOK/index.html?node54.html - rgb_dtype = 'float32' - if dtype == 'float64': - rgb_dtype = dtype - r = numpy.asarray(.3, dtype=rgb_dtype) - g = numpy.asarray(.59, dtype=rgb_dtype) - b = numpy.asarray(.11, dtype=rgb_dtype) + if color=='grey': + # Y = 0.3R + 0.59G + 0.11B from + # http://gimp-savvy.com/BOOK/index.html?node54.html + rgb_dtype = 'float32' + if dtype == 'float64': + rgb_dtype = dtype + r = numpy.asarray(.3, dtype=rgb_dtype) + g = numpy.asarray(.59, dtype=rgb_dtype) + b = numpy.asarray(.11, dtype=rgb_dtype) if x.ndim == 1: if rasterized: @@ -148,8 +167,8 @@ x = theano.tensor.cast(x, 'uint8') x.reshape((N, 32, 32)) elif color=='rgb': - # the strides aren't what you'd expect between channels, - # but theano is all about weird strides + # note: the strides aren't what you'd expect between channels, + # but a copy of the data would correct that. x = x.reshape((N,3,32,32)).dimshuffle(0, 2, 3, 1) else: raise NotImplemented('color', color)