annotate pylearn/datasets/nist_digits.py @ 1484:83d3c9ee6d65

* changed MNIST dataset to use config.get_filepath_in_roots mechanism
author gdesjardins
date Tue, 05 Jul 2011 11:01:51 -0400
parents 87d416e1f4fd
children
rev   line source
928
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
1 """
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
2 Provides a Dataset to access the nist digits dataset.
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
3 """
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
4
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
5 import os, numpy
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
6 from pylearn.io import filetensor as ft
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
7 from pylearn.datasets.config import data_root # config
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
8 from pylearn.datasets.dataset import Dataset
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
9
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
10 from pylearn.datasets.nist_sd import nist_to_float_11, nist_to_float_01
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
11
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
12
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
def load(dataset = 'train', attribute = 'data'):
    """Load the filetensor corresponding to the set and attribute.

    The file read is
    ``<data_root()>/nist/by_class/digits/digits_<dataset>_<attribute>.ft``.

    :param dataset: str that is 'train', 'valid' or 'test'
    :param attribute: str that is 'data' or 'labels'
    :returns: the object returned by ``ft.read`` for that file.
    """
    fn = 'digits_' + dataset + '_' + attribute + '.ft'
    fn = os.path.join(data_root(), 'nist', 'by_class', 'digits', fn)

    # Filetensor files hold binary data, so open in 'rb' (no newline or
    # encoding translation on platforms that distinguish text mode).
    # The ``with`` block guarantees the handle is closed even if
    # ft.read raises.
    with open(fn, 'rb') as fd:
        data = ft.read(fd)

    return data
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
27
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
def train_valid_test(ntrain=285661, nvalid=58646, ntest=58646, path=None,
        range = '01'):
    """
    Load the nist digits dataset as a Dataset.

    The training and validation sets are both read from the 'train' files.
    The training set is the first ``ntrain`` rows. The validation set is the
    ``nvalid`` rows starting at a fixed row offset of 285661, so the
    validation examples do not change when ``ntrain`` changes. The test set
    is the first ``ntest`` rows of the 'test' files.

    :param ntrain: number of training rows to keep (slice bound). The
        resulting training set must not extend past row 285661, or it would
        overlap the validation set.
    :param nvalid: number of validation rows, taken from row 285661 onward.
    :param ntest: number of test rows to keep (slice bound).
    :param path: unused by this function; kept for interface compatibility.
    :param range: '01' selects ``nist_to_float_01`` and '11' selects
        ``nist_to_float_11`` as ``rval.preprocess``.
    :returns: a Dataset with ``n_classes``, ``img_shape``, ``preprocess``,
        ``train``, ``valid`` and ``test`` attributes set.
    :raises ValueError: if ``range`` is not '01' or '11', or if ``ntrain``
        selects training rows that overlap the validation window.

    @note: the examples are uint8 and the labels are int32.
    @todo: possibility of loading part of the data.
    """
    # Row at which the validation split starts inside the 'train' files.
    valid_start = 285661

    rval = Dataset()

    rval.n_classes = 10
    rval.img_shape = (32,32)

    if range == '01':
        rval.preprocess = nist_to_float_01
    elif range == '11':
        rval.preprocess = nist_to_float_11
    else:
        raise ValueError('Nist Digits dataset does not support range = %s' % range)
    # A single parenthesized argument prints identically under the Python 2
    # print statement and the Python 3 print function.
    print("Nist Digits dataset: using preproc will provide inputs in the %s range."
          % range)

    # train: the first ntrain rows of the 'train' files.
    examples = load(dataset = 'train', attribute = 'data')
    labels = load(dataset = 'train', attribute = 'labels')
    train_x = examples[:ntrain]
    # Measure the slice rather than ntrain itself, so that None and negative
    # bounds are checked too. Any training row at or beyond valid_start
    # would leak validation data into training.
    if len(train_x) > valid_start:
        raise ValueError('ntrain=%s makes the training set overlap the '
                         'validation set, which starts at row %d'
                         % (ntrain, valid_start))
    rval.train = Dataset.Obj(x=train_x, y=labels[:ntrain])

    # valid: a fixed window of the 'train' files, independent of ntrain.
    valid_stop = valid_start + nvalid
    rval.valid = Dataset.Obj(x=examples[valid_start:valid_stop],
                             y=labels[valid_start:valid_stop])

    # test: the first ntest rows of the 'test' files.
    examples = load(dataset = 'test', attribute = 'data')
    labels = load(dataset = 'test', attribute = 'labels')
    rval.test = Dataset.Obj(x=examples[:ntest], y=labels[:ntest])

    return rval
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
65
87d416e1f4fd Add nist_digits dataset.
Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
parents:
diff changeset
66