view datasets/defs.py @ 222:4cfd0eb438af

Add mnist to datasets (and supporting code).
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 11 Mar 2010 14:41:31 -0500
parents 476da2ba6a12
children 6f4e3719a3cc
line wrap: on
line source

# Public dataset objects exported by a star-import of this module.
__all__ = [
    'nist_digits',
    'nist_lower',
    'nist_upper',
    'nist_all',
    'ocr',
    'nist_P07',
    'mnist',
]

from ftfile import FTDataSet
from gzpklfile import GzpklDataSet
import theano

import os

# Root directories of the on-disk dataset files.  The defaults point at the
# LISA lab filesystem; setting IFT6266_NIST_PATH / IFT6266_DATA_PATH in the
# environment (to a non-empty value) selects a copy stored elsewhere.
# The dataset definitions in this module build file names by plain string
# concatenation (ROOT + 'sub/file.ft'), so each root is normalised to end
# with a path separator; os.path.join(root, '') leaves an already-terminated
# root unchanged.
NIST_PATH = os.path.join(
    os.environ.get('IFT6266_NIST_PATH') or '/data/lisa/data/nist/by_class/',
    '')
DATA_PATH = os.path.join(
    os.environ.get('IFT6266_DATA_PATH') or '/data/lisa/data/ift6266h10/',
    '')

def _nist_by_class(name):
    """Build the FTDataSet for one NIST ``by_class`` subset.

    ``name`` is both the sub-directory under NIST_PATH and the file-name
    prefix, e.g. 'digits' -> NIST_PATH + 'digits/digits_train_data.ft'.
    Only train and test files are supplied for these subsets (no
    valid_data / valid_lbl).
    """
    prefix = NIST_PATH + name + '/' + name
    return FTDataSet(train_data=[prefix + '_train_data.ft'],
                     train_lbl=[prefix + '_train_labels.ft'],
                     test_data=[prefix + '_test_data.ft'],
                     test_lbl=[prefix + '_test_labels.ft'],
                     indtype=theano.config.floatX, inscale=255.)

nist_digits = _nist_by_class('digits')
nist_lower = _nist_by_class('lower')
nist_upper = _nist_by_class('upper')

# Dataset stored directly under DATA_PATH with train / validation / test
# splits (presumably the union of the NIST subsets above -- confirm).
nist_all = FTDataSet(
    train_data=[DATA_PATH + 'train_data.ft'],
    train_lbl=[DATA_PATH + 'train_labels.ft'],
    valid_data=[DATA_PATH + 'valid_data.ft'],
    valid_lbl=[DATA_PATH + 'valid_labels.ft'],
    test_data=[DATA_PATH + 'test_data.ft'],
    test_lbl=[DATA_PATH + 'test_labels.ft'],
    indtype=theano.config.floatX,
    inscale=255.
)

# OCR dataset: 'ocr_'-prefixed files under DATA_PATH, with train /
# validation / test splits.
ocr = FTDataSet(
    train_data=[DATA_PATH + 'ocr_train_data.ft'],
    train_lbl=[DATA_PATH + 'ocr_train_labels.ft'],
    valid_data=[DATA_PATH + 'ocr_valid_data.ft'],
    valid_lbl=[DATA_PATH + 'ocr_valid_labels.ft'],
    test_data=[DATA_PATH + 'ocr_test_data.ft'],
    test_lbl=[DATA_PATH + 'ocr_test_labels.ft'],
    indtype=theano.config.floatX,
    inscale=255.
)

# P07 dataset under DATA_PATH + 'data/'.  The training split is spread over
# 100 numbered shard files (P07_train0 .. P07_train99); validation and test
# are single files.
nist_P07 = FTDataSet(
    train_data=['%sdata/P07_train%d_data.ft' % (DATA_PATH, i)
                for i in range(100)],
    train_lbl=['%sdata/P07_train%d_labels.ft' % (DATA_PATH, i)
               for i in range(100)],
    valid_data=['%sdata/P07_valid_data.ft' % DATA_PATH],
    valid_lbl=['%sdata/P07_valid_labels.ft' % DATA_PATH],
    test_data=['%sdata/P07_test_data.ft' % DATA_PATH],
    test_lbl=['%sdata/P07_test_labels.ft' % DATA_PATH],
    indtype=theano.config.floatX,
    inscale=255.
)

# MNIST, read from a single file (presumably a gzip-compressed pickle, per
# the name) under DATA_PATH.
# NOTE(review): unlike the FTDataSet entries, no indtype/inscale is passed
# here -- confirm GzpklDataSet's defaults give the same dtype and value range.
mnist = GzpklDataSet('%smnist.pkl.gz' % DATA_PATH)