view datasets/defs.py @ 239:42005ec87747

Mergé (manuellement) les changements de Sylvain pour utiliser le code de dataset d'Arnaud, à cette différence près que je n'utilise pas les givens. J'ai probablement une approche différente pour limiter la taille du dataset dans mon débogage, aussi.
author fsavard
date Mon, 15 Mar 2010 18:30:21 -0400
parents 6f4e3719a3cc
children 966272e7f14b
line wrap: on
line source

# Names exported by a star-import of this module: one entry per dataset
# object defined below.
__all__ = [
    'nist_digits',
    'nist_lower',
    'nist_upper',
    'nist_all',
    'ocr',
    'nist_P07',
    'mnist',
]

from ftfile import FTDataSet
from gzpklfile import GzpklDataSet
import theano
import os

# Dataset root directories.  Each one can be overridden through the
# environment variable of the same name; when that variable is not set,
# the hard-coded default location is used instead.
NIST_PATH = os.environ.get('NIST_PATH', '/data/lisa/data/nist/by_class/')
DATA_PATH = os.environ.get('DATA_PATH', '/data/lisa/data/ift6266h10/')

# 'digits' subset of the NIST data found under NIST_PATH.  Only train and
# test files are supplied here (no validation files).  indtype and inscale
# are forwarded to FTDataSet as-is -- presumably the input dtype and the
# raw pixel scale; confirm in ftfile.
nist_digits = FTDataSet(
    train_data=[os.path.join(NIST_PATH, 'digits/digits_train_data.ft')],
    train_lbl=[os.path.join(NIST_PATH, 'digits/digits_train_labels.ft')],
    test_data=[os.path.join(NIST_PATH, 'digits/digits_test_data.ft')],
    test_lbl=[os.path.join(NIST_PATH, 'digits/digits_test_labels.ft')],
    indtype=theano.config.floatX,
    inscale=255.,
)
# 'lower' subset of the NIST data found under NIST_PATH.  Only train and
# test files are supplied here (no validation files).  indtype and inscale
# are forwarded to FTDataSet as-is -- presumably the input dtype and the
# raw pixel scale; confirm in ftfile.
nist_lower = FTDataSet(
    train_data=[os.path.join(NIST_PATH, 'lower/lower_train_data.ft')],
    train_lbl=[os.path.join(NIST_PATH, 'lower/lower_train_labels.ft')],
    test_data=[os.path.join(NIST_PATH, 'lower/lower_test_data.ft')],
    test_lbl=[os.path.join(NIST_PATH, 'lower/lower_test_labels.ft')],
    indtype=theano.config.floatX,
    inscale=255.,
)
# 'upper' subset of the NIST data found under NIST_PATH.  Only train and
# test files are supplied here (no validation files).  indtype and inscale
# are forwarded to FTDataSet as-is -- presumably the input dtype and the
# raw pixel scale; confirm in ftfile.
nist_upper = FTDataSet(
    train_data=[os.path.join(NIST_PATH, 'upper/upper_train_data.ft')],
    train_lbl=[os.path.join(NIST_PATH, 'upper/upper_train_labels.ft')],
    test_data=[os.path.join(NIST_PATH, 'upper/upper_test_data.ft')],
    test_lbl=[os.path.join(NIST_PATH, 'upper/upper_test_labels.ft')],
    indtype=theano.config.floatX,
    inscale=255.,
)

# Dataset built from the unprefixed train/test/valid files that sit
# directly under DATA_PATH.  Unlike the per-class NIST subsets, validation
# files are supplied as well.  indtype and inscale are forwarded to
# FTDataSet as-is -- presumably the input dtype and the raw pixel scale;
# confirm in ftfile.
nist_all = FTDataSet(
    train_data=[os.path.join(DATA_PATH, 'train_data.ft')],
    train_lbl=[os.path.join(DATA_PATH, 'train_labels.ft')],
    test_data=[os.path.join(DATA_PATH, 'test_data.ft')],
    test_lbl=[os.path.join(DATA_PATH, 'test_labels.ft')],
    valid_data=[os.path.join(DATA_PATH, 'valid_data.ft')],
    valid_lbl=[os.path.join(DATA_PATH, 'valid_labels.ft')],
    indtype=theano.config.floatX,
    inscale=255.,
)

# Dataset built from the 'ocr_'-prefixed train/test/valid files that sit
# directly under DATA_PATH.  indtype and inscale are forwarded to FTDataSet
# as-is -- presumably the input dtype and the raw pixel scale; confirm in
# ftfile.
ocr = FTDataSet(
    train_data=[os.path.join(DATA_PATH, 'ocr_train_data.ft')],
    train_lbl=[os.path.join(DATA_PATH, 'ocr_train_labels.ft')],
    test_data=[os.path.join(DATA_PATH, 'ocr_test_data.ft')],
    test_lbl=[os.path.join(DATA_PATH, 'ocr_test_labels.ft')],
    valid_data=[os.path.join(DATA_PATH, 'ocr_valid_data.ft')],
    valid_lbl=[os.path.join(DATA_PATH, 'ocr_valid_labels.ft')],
    indtype=theano.config.floatX,
    inscale=255.,
)

# P07 dataset, stored in the 'data' subdirectory of DATA_PATH.  The training
# set is spread over 100 numbered files (P07_train0 ... P07_train99), listed
# in index order with data and label lists kept parallel; test and
# validation each use a single file.  indtype and inscale are forwarded to
# FTDataSet as-is -- presumably the input dtype and the raw pixel scale;
# confirm in ftfile.
nist_P07 = FTDataSet(
    train_data=[
        os.path.join(DATA_PATH, 'data/P07_train'+str(i)+'_data.ft')
        for i in range(100)
    ],
    train_lbl=[
        os.path.join(DATA_PATH, 'data/P07_train'+str(i)+'_labels.ft')
        for i in range(100)
    ],
    test_data=[os.path.join(DATA_PATH, 'data/P07_test_data.ft')],
    test_lbl=[os.path.join(DATA_PATH, 'data/P07_test_labels.ft')],
    valid_data=[os.path.join(DATA_PATH, 'data/P07_valid_data.ft')],
    valid_lbl=[os.path.join(DATA_PATH, 'data/P07_valid_labels.ft')],
    indtype=theano.config.floatX,
    inscale=255.,
)

# MNIST is the only dataset here that is not backed by .ft files: it is
# handed to GzpklDataSet as a single 'mnist.pkl.gz' path under DATA_PATH.
mnist = GzpklDataSet(
    os.path.join(DATA_PATH, 'mnist.pkl.gz')
)