changeset 928:87d416e1f4fd

Add nist_digits dataset.
author Pierre-Antoine Manzagol <pierre.antoine.manzagol@gmail.com>
date Fri, 09 Apr 2010 11:14:24 -0400
parents ffaf94da8100
children 34d1cd516f76
files pylearn/datasets/nist_digits.py
diffstat 1 files changed, 66 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/datasets/nist_digits.py	Fri Apr 09 11:14:24 2010 -0400
@@ -0,0 +1,66 @@
+"""
+Provides a Dataset to access the nist digits dataset. 
+"""
+
+import os, numpy
+from pylearn.io import filetensor as ft
+from pylearn.datasets.config import data_root # config
+from pylearn.datasets.dataset import Dataset
+
+from pylearn.datasets.nist_sd import nist_to_float_11, nist_to_float_01
+
+
+def load(dataset = 'train', attribute = 'data'):
+  """Load the filetensor corresponding to the set and attribute.
+
+  :param dataset: str that is 'train', 'valid' or 'test'
+  :param attribute: str that is 'data' or 'labels'
+  """
+  fn = 'digits_' + dataset + '_' + attribute + '.ft'
+  fn = os.path.join(data_root(), 'nist', 'by_class', 'digits', fn)
+
+  fd = open(fn)
+  data = ft.read(fd)
+  fd.close()
+
+  return data
+
+def train_valid_test(ntrain=285661, nvalid=58646, ntest=58646, path=None,
+    range = '01'):
+  """
+  Load the nist digits dataset as a Dataset.
+
+  @note: the examples are uint8 and the labels are int32.
+  @todo: possibility of loading part of the data.
+  """
+  rval = Dataset()
+
+  # 
+  rval.n_classes = 10
+  rval.img_shape = (32,32)
+
+  if range == '01':
+    rval.preprocess = nist_to_float_01
+  elif range == '11':
+    rval.preprocess = nist_to_float_11
+  else:
+    raise ValueError('Nist Digits dataset does not support range = %s' % range)
+  print "Nist Digits dataset: using preproc will provide inputs in the %s range." \
+      % range
+
+  # train
+  examples = load(dataset = 'train', attribute = 'data')
+  labels = load(dataset = 'train', attribute = 'labels')
+  rval.train = Dataset.Obj(x=examples[:ntrain], y=labels[:ntrain])
+
+  # valid
+  rval.valid = Dataset.Obj(x=examples[285661:285661+nvalid], y=labels[285661:285661+nvalid])
+
+  # test
+  examples = load(dataset = 'test', attribute = 'data')
+  labels = load(dataset = 'test', attribute = 'labels')
+  rval.test = Dataset.Obj(x=examples[:ntest], y=labels[:ntest])
+  
+  return rval
+
+