Mercurial > pylearn
view pylearn/datasets/nist_sd.py @ 1479:1b69d435f09f
fix error string.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Wed, 25 May 2011 09:26:47 -0400 |
parents | 2e87264493ef |
children |
line wrap: on
line source
"""Provides a Dataset to access the nist digits_reshuffled dataset."""
import os
import numpy

from pylearn.io import filetensor as ft
from pylearn.datasets.config import data_root # config
from pylearn.datasets.dataset import Dataset


def nist_to_float_11(x):
    """Rescale ``x`` with ``(x - 128.0) / 128.0``.

    For raw pixel values in 0..255 the result lies in [-1.0, 1.0).
    """
    return (x - 128.0) / 128.0


def nist_to_float_01(x):
    """Rescale ``x`` with ``x / 255.0``.

    For raw pixel values in 0..255 the result lies in [0.0, 1.0].
    """
    return x / 255.0


def load(dataset='train', attribute='data'):
    """Load the filetensor corresponding to the set and attribute.

    The file read is
    ``<data_root>/nist/by_class/digits_reshuffled/``
    ``digits_reshuffled_<dataset>_<attribute>.ft``.

    :param dataset: str that is 'train', 'valid' or 'test'
    :param attribute: str that is 'data' or 'labels'
    :returns: whatever ``ft.read`` returns for that file.
    :raises IOError: if the file cannot be opened.
    """
    fn = 'digits_reshuffled_' + dataset + '_' + attribute + '.ft'
    fn = os.path.join(data_root(), 'nist', 'by_class', 'digits_reshuffled', fn)
    # Filetensor files are binary, so open in 'rb'; the ``with`` block makes
    # sure the handle is closed even when ``ft.read`` raises.
    with open(fn, 'rb') as fd:
        return ft.read(fd)


def _load_split(dataset, n):
    """Load one split and keep only its first ``n`` examples and labels."""
    examples = load(dataset=dataset, attribute='data')
    labels = load(dataset=dataset, attribute='labels')
    return Dataset.Obj(x=examples[:n], y=labels[:n])


def train_valid_test(ntrain=285661, nvalid=58646, ntest=58646, path=None,
                     range='01'):
    """Load the nist reshuffled digits dataset as a Dataset.

    :param ntrain: keep at most this many leading training examples.
    :param nvalid: keep at most this many leading validation examples.
    :param ntest: keep at most this many leading test examples.
    :param path: currently unused; accepted for interface compatibility.
    :param range: '01' selects ``nist_to_float_01`` as ``preprocess``,
        '11' selects ``nist_to_float_11``.  (The name shadows the builtin
        but is part of the public signature, so it is kept.)
    :returns: a ``Dataset`` with ``n_classes``, ``img_shape``,
        ``preprocess``, ``train``, ``valid`` and ``test`` attributes; each
        split is a ``Dataset.Obj`` with ``x`` (examples) and ``y`` (labels).
    :raises ValueError: if ``range`` is neither '01' nor '11'.

    @note: the examples are uint8 and the labels are int32.
    @todo: possibility of loading part of the data (each file is currently
        read in full and then sliced).
    """
    rval = Dataset()

    rval.n_classes = 10
    rval.img_shape = (32, 32)

    if range == '01':
        rval.preprocess = nist_to_float_01
    elif range == '11':
        rval.preprocess = nist_to_float_11
    else:
        raise ValueError('Nist SD dataset does not support range = %s' % range)
    # Single parenthesised argument: prints the same text under Python 2 and
    # also parses under Python 3.
    print("Nist SD dataset: using preproc will provide inputs in the %s range."
          % range)

    rval.train = _load_split('train', ntrain)
    rval.valid = _load_split('valid', nvalid)
    rval.test = _load_split('test', ntest)

    return rval