changeset 899:2e87264493ef

NistSD dataset: Add range argument for input [-1,1] or [0,1]. The train_valid_test function now takes a range argument that determines which preproc function to set. Valid ranges are '01' (for use with sigmoid) and '11' (ie [-1,1], for use with tanh).
author Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
date Thu, 11 Feb 2010 09:44:34 -0500
parents cdbfdbf7ec56
children f36d68b65f2c
files pylearn/datasets/nist_sd.py
diffstat 1 files changed, 15 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/datasets/nist_sd.py	Tue Feb 09 22:11:52 2010 -0500
+++ b/pylearn/datasets/nist_sd.py	Thu Feb 11 09:44:34 2010 -0500
@@ -7,9 +7,12 @@
 from pylearn.datasets.config import data_root # config
 from pylearn.datasets.dataset import Dataset
 
-def nist_to_float(x):
+def nist_to_float_11(x):
   return (x - 128.0)/ 128.0
 
+def nist_to_float_01(x):
+  return x / 255.0
+
 def load(dataset = 'train', attribute = 'data'):
   """Load the filetensor corresponding to the set and attribute.
 
@@ -25,7 +28,8 @@
 
   return data
 
-def train_valid_test(ntrain=285661, nvalid=58646, ntest=58646, path=None):
+def train_valid_test(ntrain=285661, nvalid=58646, ntest=58646, path=None,
+    range = '01'):
   """
   Load the nist reshuffled digits dataset as a Dataset.
 
@@ -37,7 +41,15 @@
   # 
   rval.n_classes = 10
   rval.img_shape = (32,32)
-  rval.preprocess = nist_to_float
+
+  if range == '01':
+    rval.preprocess = nist_to_float_01
+  elif range == '11':
+    rval.preprocess = nist_to_float_11
+  else:
+    raise ValueError('Nist SD dataset does not support range = %s' % range)
+  print "Nist SD dataset: using preproc will provide inputs in the %s range." \
+      % range
 
   # train
   examples = load(dataset = 'train', attribute = 'data')