Mercurial > pylearn
changeset 899:2e87264493ef
NistSD dataset: Add range argument for input [-1,1] or [0,1].
The train_valid_test function now takes a range argument that determines
which preproc function to set. Valid ranges are '01' (for use with sigmoid)
and '11' (ie [-1,1], for use with tanh).
author | Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com> |
---|---|
date | Thu, 11 Feb 2010 09:44:34 -0500 |
parents | cdbfdbf7ec56 |
children | f36d68b65f2c |
files | pylearn/datasets/nist_sd.py |
diffstat | 1 files changed, 15 insertions(+), 3 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/datasets/nist_sd.py Tue Feb 09 22:11:52 2010 -0500 +++ b/pylearn/datasets/nist_sd.py Thu Feb 11 09:44:34 2010 -0500 @@ -7,9 +7,12 @@ from pylearn.datasets.config import data_root # config from pylearn.datasets.dataset import Dataset -def nist_to_float(x): +def nist_to_float_11(x): return (x - 128.0)/ 128.0 +def nist_to_float_01(x): + return x / 255.0 + def load(dataset = 'train', attribute = 'data'): """Load the filetensor corresponding to the set and attribute. @@ -25,7 +28,8 @@ return data -def train_valid_test(ntrain=285661, nvalid=58646, ntest=58646, path=None): +def train_valid_test(ntrain=285661, nvalid=58646, ntest=58646, path=None, + range = '01'): """ Load the nist reshuffled digits dataset as a Dataset. @@ -37,7 +41,15 @@ # rval.n_classes = 10 rval.img_shape = (32,32) - rval.preprocess = nist_to_float + + if range == '01': + rval.preprocess = nist_to_float_01 + elif range == '11': + rval.preprocess = nist_to_float_11 + else: + raise ValueError('Nist SD dataset does not support range = %s' % range) + print "Nist SD dataset: using preproc will provide inputs in the %s range." \ + % range # train examples = load(dataset = 'train', attribute = 'data')