view datasets/defs.py @ 266:1e4e60ddadb1

Merge. Ah, et dans le dernier commit, j'avais oublié de mentionner que j'ai ajouté du code pour gérer l'isolation de différents clones pour rouler des expériences et modifier le code en même temps.
author fsavard
date Fri, 19 Mar 2010 10:56:16 -0400
parents 966272e7f14b
children 4533350d7361
line wrap: on
line source

# Public names exported by this module: one factory per dataset.
__all__ = [
    'nist_digits',
    'nist_lower',
    'nist_upper',
    'nist_all',
    'ocr',
    'nist_P07',
    'mnist',
]

from ftfile import FTDataSet
from gzpklfile import GzpklDataSet
import theano
import os

# Root directories of the datasets.  Each one is read from the
# environment variable of the same name when that variable is set;
# otherwise the hard-coded default location is used.
NIST_PATH = os.environ.get('NIST_PATH', '/data/lisa/data/nist/by_class/')
DATA_PATH = os.environ.get('DATA_PATH', '/data/lisa/data/ift6266h10/')

def nist_digits(maxsize=None):
    """Return an FTDataSet over the NIST digit files.

    The train and test data/label files are looked up in the ``digits``
    subdirectory of ``NIST_PATH``; no validation files are passed.

    maxsize -- forwarded unchanged to FTDataSet (default: None); see
               FTDataSet for its exact meaning.
    """
    # inscale=255. is passed through to FTDataSet as in every other .ft
    # dataset of this module (presumably the pixel value range -- confirm
    # in FTDataSet).
    return FTDataSet(
        train_data=[os.path.join(NIST_PATH, 'digits/digits_train_data.ft')],
        train_lbl=[os.path.join(NIST_PATH, 'digits/digits_train_labels.ft')],
        test_data=[os.path.join(NIST_PATH, 'digits/digits_test_data.ft')],
        test_lbl=[os.path.join(NIST_PATH, 'digits/digits_test_labels.ft')],
        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)

def nist_lower(maxsize=None):
    """Return an FTDataSet over the NIST lowercase-letter files.

    The train and test data/label files are looked up in the ``lower``
    subdirectory of ``NIST_PATH``; no validation files are passed.

    maxsize -- forwarded unchanged to FTDataSet (default: None); see
               FTDataSet for its exact meaning.
    """
    return FTDataSet(
        train_data=[os.path.join(NIST_PATH, 'lower/lower_train_data.ft')],
        train_lbl=[os.path.join(NIST_PATH, 'lower/lower_train_labels.ft')],
        test_data=[os.path.join(NIST_PATH, 'lower/lower_test_data.ft')],
        test_lbl=[os.path.join(NIST_PATH, 'lower/lower_test_labels.ft')],
        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)

def nist_upper(maxsize=None):
    """Return an FTDataSet over the NIST uppercase-letter files.

    The train and test data/label files are looked up in the ``upper``
    subdirectory of ``NIST_PATH``; no validation files are passed.

    maxsize -- forwarded unchanged to FTDataSet (default: None); see
               FTDataSet for its exact meaning.
    """
    return FTDataSet(
        train_data=[os.path.join(NIST_PATH, 'upper/upper_train_data.ft')],
        train_lbl=[os.path.join(NIST_PATH, 'upper/upper_train_labels.ft')],
        test_data=[os.path.join(NIST_PATH, 'upper/upper_test_data.ft')],
        test_lbl=[os.path.join(NIST_PATH, 'upper/upper_test_labels.ft')],
        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)

def nist_all(maxsize=None):
    """Return an FTDataSet over the combined train/test/valid files.

    All six files (data and labels for train, test and valid) are looked
    up directly under ``DATA_PATH``.

    maxsize -- forwarded unchanged to FTDataSet (default: None); see
               FTDataSet for its exact meaning.
    """
    return FTDataSet(
        train_data=[os.path.join(DATA_PATH, 'train_data.ft')],
        train_lbl=[os.path.join(DATA_PATH, 'train_labels.ft')],
        test_data=[os.path.join(DATA_PATH, 'test_data.ft')],
        test_lbl=[os.path.join(DATA_PATH, 'test_labels.ft')],
        valid_data=[os.path.join(DATA_PATH, 'valid_data.ft')],
        valid_lbl=[os.path.join(DATA_PATH, 'valid_labels.ft')],
        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)

def ocr(maxsize=None):
    """Return an FTDataSet over the OCR train/test/valid files.

    All six files carry the ``ocr_`` prefix and are looked up directly
    under ``DATA_PATH``.

    maxsize -- forwarded unchanged to FTDataSet (default: None); see
               FTDataSet for its exact meaning.
    """
    return FTDataSet(
        train_data=[os.path.join(DATA_PATH, 'ocr_train_data.ft')],
        train_lbl=[os.path.join(DATA_PATH, 'ocr_train_labels.ft')],
        test_data=[os.path.join(DATA_PATH, 'ocr_test_data.ft')],
        test_lbl=[os.path.join(DATA_PATH, 'ocr_test_labels.ft')],
        valid_data=[os.path.join(DATA_PATH, 'ocr_valid_data.ft')],
        valid_lbl=[os.path.join(DATA_PATH, 'ocr_valid_labels.ft')],
        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)

def nist_P07(maxsize=None):
    """Return an FTDataSet over the P07 train/test/valid files.

    The training set is split over 100 file pairs,
    ``data/P07_train0_data.ft`` / ``data/P07_train0_labels.ft`` up to
    index 99; the test and valid sets are one file pair each.  All paths
    are relative to ``DATA_PATH``.

    maxsize -- forwarded unchanged to FTDataSet (default: None); see
               FTDataSet for its exact meaning.
    """
    # Data and label lists are built from the same index sequence so the
    # i-th data file always lines up with the i-th label file.
    train_data = []
    train_lbl = []
    for i in range(100):
        stem = 'data/P07_train' + str(i)
        train_data.append(os.path.join(DATA_PATH, stem + '_data.ft'))
        train_lbl.append(os.path.join(DATA_PATH, stem + '_labels.ft'))

    return FTDataSet(
        train_data=train_data,
        train_lbl=train_lbl,
        test_data=[os.path.join(DATA_PATH, 'data/P07_test_data.ft')],
        test_lbl=[os.path.join(DATA_PATH, 'data/P07_test_labels.ft')],
        valid_data=[os.path.join(DATA_PATH, 'data/P07_valid_data.ft')],
        valid_lbl=[os.path.join(DATA_PATH, 'data/P07_valid_labels.ft')],
        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)

def mnist(maxsize=None):
    """Return a GzpklDataSet reading ``mnist.pkl.gz`` under ``DATA_PATH``.

    maxsize -- forwarded unchanged to GzpklDataSet (default: None); see
               GzpklDataSet for its exact meaning.
    """
    return GzpklDataSet(os.path.join(DATA_PATH, 'mnist.pkl.gz'),
                        maxsize=maxsize)