comparison datasets/defs.py @ 246:2024368a8d3d

merge
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Tue, 16 Mar 2010 12:14:10 -0400
parents 6f4e3719a3cc
children 966272e7f14b
comparison
equal deleted inserted replaced
245:0de14b2034c6 246:2024368a8d3d
1 __all__ = ['nist_digits', 'nist_lower', 'nist_upper', 'nist_all', 'ocr'] 1 __all__ = ['nist_digits', 'nist_lower', 'nist_upper', 'nist_all', 'ocr',
2 'nist_P07', 'mnist']
2 3
3 from ftfile import FTDataSet 4 from ftfile import FTDataSet
5 from gzpklfile import GzpklDataSet
4 import theano 6 import theano
7 import os
5 8
6 NIST_PATH = '/data/lisa/data/nist/by_class/' 9 # if the environment variables exist, get the path from them,
7 DATA_PATH = '/data/lisa/data/ift6266h10/' 10 # otherwise fall back on the default
11 NIST_PATH = os.getenv('NIST_PATH','/data/lisa/data/nist/by_class/')
12 DATA_PATH = os.getenv('DATA_PATH','/data/lisa/data/ift6266h10/')
8 13
9 nist_digits = FTDataSet(train_data = [NIST_PATH+'digits/digits_train_data.ft'], 14 nist_digits = FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')],
10 train_lbl = [NIST_PATH+'digits/digits_train_labels.ft'], 15 train_lbl = [os.path.join(NIST_PATH,'digits/digits_train_labels.ft')],
11 test_data = [NIST_PATH+'digits/digits_test_data.ft'], 16 test_data = [os.path.join(NIST_PATH,'digits/digits_test_data.ft')],
12 test_lbl = [NIST_PATH+'digits/digits_test_labels.ft'], 17 test_lbl = [os.path.join(NIST_PATH,'digits/digits_test_labels.ft')],
13 indtype=theano.config.floatX, inscale=255.) 18 indtype=theano.config.floatX, inscale=255.)
14 nist_lower = FTDataSet(train_data = [NIST_PATH+'lower/lower_train_data.ft'], 19 nist_lower = FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')],
15 train_lbl = [NIST_PATH+'lower/lower_train_labels.ft'], 20 train_lbl = [os.path.join(NIST_PATH,'lower/lower_train_labels.ft')],
16 test_data = [NIST_PATH+'lower/lower_test_data.ft'], 21 test_data = [os.path.join(NIST_PATH,'lower/lower_test_data.ft')],
17 test_lbl = [NIST_PATH+'lower/lower_test_labels.ft'], 22 test_lbl = [os.path.join(NIST_PATH,'lower/lower_test_labels.ft')],
18 indtype=theano.config.floatX, inscale=255.) 23 indtype=theano.config.floatX, inscale=255.)
19 nist_upper = FTDataSet(train_data = [NIST_PATH+'upper/upper_train_data.ft'], 24 nist_upper = FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')],
20 train_lbl = [NIST_PATH+'upper/upper_train_labels.ft'], 25 train_lbl = [os.path.join(NIST_PATH,'upper/upper_train_labels.ft')],
21 test_data = [NIST_PATH+'upper/upper_test_data.ft'], 26 test_data = [os.path.join(NIST_PATH,'upper/upper_test_data.ft')],
22 test_lbl = [NIST_PATH+'upper/upper_test_labels.ft'], 27 test_lbl = [os.path.join(NIST_PATH,'upper/upper_test_labels.ft')],
23 indtype=theano.config.floatX, inscale=255.) 28 indtype=theano.config.floatX, inscale=255.)
24 29
25 nist_all = FTDataSet(train_data = [DATA_PATH+'train_data.ft'], 30 nist_all = FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')],
26 train_lbl = [DATA_PATH+'train_labels.ft'], 31 train_lbl = [os.path.join(DATA_PATH,'train_labels.ft')],
27 test_data = [DATA_PATH+'test_data.ft'], 32 test_data = [os.path.join(DATA_PATH,'test_data.ft')],
28 test_lbl = [DATA_PATH+'test_labels.ft'], 33 test_lbl = [os.path.join(DATA_PATH,'test_labels.ft')],
29 valid_data = [DATA_PATH+'valid_data.ft'], 34 valid_data = [os.path.join(DATA_PATH,'valid_data.ft')],
30 valid_lbl = [DATA_PATH+'valid_labels.ft'], 35 valid_lbl = [os.path.join(DATA_PATH,'valid_labels.ft')],
31 indtype=theano.config.floatX, inscale=255.) 36 indtype=theano.config.floatX, inscale=255.)
32 37
33 ocr = FTDataSet(train_data = [DATA_PATH+'ocr_train_data.ft'], 38 ocr = FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')],
34 train_lbl = [DATA_PATH+'ocr_train_labels.ft'], 39 train_lbl = [os.path.join(DATA_PATH,'ocr_train_labels.ft')],
35 test_data = [DATA_PATH+'ocr_test_data.ft'], 40 test_data = [os.path.join(DATA_PATH,'ocr_test_data.ft')],
36 test_lbl = [DATA_PATH+'ocr_test_labels.ft'], 41 test_lbl = [os.path.join(DATA_PATH,'ocr_test_labels.ft')],
37 valid_data = [DATA_PATH+'ocr_valid_data.ft'], 42 valid_data = [os.path.join(DATA_PATH,'ocr_valid_data.ft')],
38 valid_lbl = [DATA_PATH+'ocr_valid_labels.ft']) 43 valid_lbl = [os.path.join(DATA_PATH,'ocr_valid_labels.ft')],
44 indtype=theano.config.floatX, inscale=255.)
45
46 nist_P07 = FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)],
47 train_lbl = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_labels.ft') for i in range(100)],
48 test_data = [os.path.join(DATA_PATH,'data/P07_test_data.ft')],
49 test_lbl = [os.path.join(DATA_PATH,'data/P07_test_labels.ft')],
50 valid_data = [os.path.join(DATA_PATH,'data/P07_valid_data.ft')],
51 valid_lbl = [os.path.join(DATA_PATH,'data/P07_valid_labels.ft')],
52 indtype=theano.config.floatX, inscale=255.)
53
54 mnist = GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz'))