# HG changeset patch
# User Pierre-Antoine Manzagol
# Date 1265899474 18000
# Node ID 2e87264493efbbfead8231442266a8008c05dcec
# Parent  cdbfdbf7ec56efd4c3fe00f1cb6cf84bfe18ea18
# NOTE(review): this patch was received collapsed onto a single line. Line
#   breaks, blank context lines and indentation below are reconstructed so that
#   each hunk matches its @@ line counts (9/12, 7/8, 7/15); the exact position
#   of blank lines, the continuation-line indents and the commit-message
#   wrapping are assumptions -- verify against the original changeset before
#   applying.
# Summary of the change as shown by the hunks:
#   * nist_to_float is renamed nist_to_float_11 (body unchanged:
#     (x - 128.0)/ 128.0) and nist_to_float_01 (x / 255.0) is added.
#   * train_valid_test gains a keyword argument range (default '01') that
#     selects rval.preprocess: '01' -> nist_to_float_01,
#     '11' -> nist_to_float_11, anything else raises ValueError.
#   * The selected range is printed on every call.
# NOTE(review): the new parameter name `range` shadows the Python builtin
#   inside train_valid_test.
# NOTE(review): the added `print` is a Python 2 print statement.
NistSD dataset: Add range argument for input [-1,1] or [0,1].

The train_valid_test function now takes a range argument that determines which
preproc function to set. Valid ranges are '01' (for use with sigmoid) and '11'
(ie [-1,1], for use with tanh).

diff -r cdbfdbf7ec56 -r 2e87264493ef pylearn/datasets/nist_sd.py
--- a/pylearn/datasets/nist_sd.py	Tue Feb 09 22:11:52 2010 -0500
+++ b/pylearn/datasets/nist_sd.py	Thu Feb 11 09:44:34 2010 -0500
@@ -7,9 +7,12 @@
 from pylearn.datasets.config import data_root # config
 from pylearn.datasets.dataset import Dataset
 
-def nist_to_float(x):
+def nist_to_float_11(x):
     return (x - 128.0)/ 128.0
 
+def nist_to_float_01(x):
+    return x / 255.0
+
 def load(dataset = 'train', attribute = 'data'):
     """Load the filetensor corresponding to the set and attribute.
 
@@ -25,7 +28,8 @@
 
     return data
 
-def train_valid_test(ntrain=285661, nvalid=58646, ntest=58646, path=None):
+def train_valid_test(ntrain=285661, nvalid=58646, ntest=58646, path=None,
+        range = '01'):
     """
     Load the nist reshuffled digits dataset as a Dataset.
 
@@ -37,7 +41,15 @@
 
     # rval.n_classes = 10
     rval.img_shape = (32,32)
-    rval.preprocess = nist_to_float
+
+    if range == '01':
+        rval.preprocess = nist_to_float_01
+    elif range == '11':
+        rval.preprocess = nist_to_float_11
+    else:
+        raise ValueError('Nist SD dataset does not support range = %s' % range)
+    print "Nist SD dataset: using preproc will provide inputs in the %s range." \
+            % range
 
     # train
     examples = load(dataset = 'train', attribute = 'data')