comparison datasets/defs.py @ 257:966272e7f14b

Make the datasets lazy-loading and add a maxsize parameter.
author Arnaud Bergeron <abergeron@gmail.com>
date Tue, 16 Mar 2010 18:51:27 -0400
parents 6f4e3719a3cc
children 4533350d7361
comparison
equal deleted inserted replaced
248:7e6fecabb656 257:966272e7f14b
9 # if the environmental variables exist, get the path from them, 9 # if the environmental variables exist, get the path from them,
10 # otherwise fall back on the default 10 # otherwise fall back on the default
11 NIST_PATH = os.getenv('NIST_PATH','/data/lisa/data/nist/by_class/') 11 NIST_PATH = os.getenv('NIST_PATH','/data/lisa/data/nist/by_class/')
12 DATA_PATH = os.getenv('DATA_PATH','/data/lisa/data/ift6266h10/') 12 DATA_PATH = os.getenv('DATA_PATH','/data/lisa/data/ift6266h10/')
13 13
14 nist_digits = FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')], 14 nist_digits = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')],
15 train_lbl = [os.path.join(NIST_PATH,'digits/digits_train_labels.ft')], 15 train_lbl = [os.path.join(NIST_PATH,'digits/digits_train_labels.ft')],
16 test_data = [os.path.join(NIST_PATH,'digits/digits_test_data.ft')], 16 test_data = [os.path.join(NIST_PATH,'digits/digits_test_data.ft')],
17 test_lbl = [os.path.join(NIST_PATH,'digits/digits_test_labels.ft')], 17 test_lbl = [os.path.join(NIST_PATH,'digits/digits_test_labels.ft')],
18 indtype=theano.config.floatX, inscale=255.) 18 indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
19 nist_lower = FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')], 19 nist_lower = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')],
20 train_lbl = [os.path.join(NIST_PATH,'lower/lower_train_labels.ft')], 20 train_lbl = [os.path.join(NIST_PATH,'lower/lower_train_labels.ft')],
21 test_data = [os.path.join(NIST_PATH,'lower/lower_test_data.ft')], 21 test_data = [os.path.join(NIST_PATH,'lower/lower_test_data.ft')],
22 test_lbl = [os.path.join(NIST_PATH,'lower/lower_test_labels.ft')], 22 test_lbl = [os.path.join(NIST_PATH,'lower/lower_test_labels.ft')],
23 indtype=theano.config.floatX, inscale=255.) 23 indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
24 nist_upper = FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')], 24 nist_upper = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')],
25 train_lbl = [os.path.join(NIST_PATH,'upper/upper_train_labels.ft')], 25 train_lbl = [os.path.join(NIST_PATH,'upper/upper_train_labels.ft')],
26 test_data = [os.path.join(NIST_PATH,'upper/upper_test_data.ft')], 26 test_data = [os.path.join(NIST_PATH,'upper/upper_test_data.ft')],
27 test_lbl = [os.path.join(NIST_PATH,'upper/upper_test_labels.ft')], 27 test_lbl = [os.path.join(NIST_PATH,'upper/upper_test_labels.ft')],
28 indtype=theano.config.floatX, inscale=255.) 28 indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
29 29
30 nist_all = FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')], 30 nist_all = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')],
31 train_lbl = [os.path.join(DATA_PATH,'train_labels.ft')], 31 train_lbl = [os.path.join(DATA_PATH,'train_labels.ft')],
32 test_data = [os.path.join(DATA_PATH,'test_data.ft')], 32 test_data = [os.path.join(DATA_PATH,'test_data.ft')],
33 test_lbl = [os.path.join(DATA_PATH,'test_labels.ft')], 33 test_lbl = [os.path.join(DATA_PATH,'test_labels.ft')],
34 valid_data = [os.path.join(DATA_PATH,'valid_data.ft')], 34 valid_data = [os.path.join(DATA_PATH,'valid_data.ft')],
35 valid_lbl = [os.path.join(DATA_PATH,'valid_labels.ft')], 35 valid_lbl = [os.path.join(DATA_PATH,'valid_labels.ft')],
36 indtype=theano.config.floatX, inscale=255.) 36 indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
37 37
38 ocr = FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')], 38 ocr = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')],
39 train_lbl = [os.path.join(DATA_PATH,'ocr_train_labels.ft')], 39 train_lbl = [os.path.join(DATA_PATH,'ocr_train_labels.ft')],
40 test_data = [os.path.join(DATA_PATH,'ocr_test_data.ft')], 40 test_data = [os.path.join(DATA_PATH,'ocr_test_data.ft')],
41 test_lbl = [os.path.join(DATA_PATH,'ocr_test_labels.ft')], 41 test_lbl = [os.path.join(DATA_PATH,'ocr_test_labels.ft')],
42 valid_data = [os.path.join(DATA_PATH,'ocr_valid_data.ft')], 42 valid_data = [os.path.join(DATA_PATH,'ocr_valid_data.ft')],
43 valid_lbl = [os.path.join(DATA_PATH,'ocr_valid_labels.ft')], 43 valid_lbl = [os.path.join(DATA_PATH,'ocr_valid_labels.ft')],
44 indtype=theano.config.floatX, inscale=255.) 44 indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
45 45
46 nist_P07 = FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)], 46 nist_P07 = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)],
47 train_lbl = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_labels.ft') for i in range(100)], 47 train_lbl = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_labels.ft') for i in range(100)],
48 test_data = [os.path.join(DATA_PATH,'data/P07_test_data.ft')], 48 test_data = [os.path.join(DATA_PATH,'data/P07_test_data.ft')],
49 test_lbl = [os.path.join(DATA_PATH,'data/P07_test_labels.ft')], 49 test_lbl = [os.path.join(DATA_PATH,'data/P07_test_labels.ft')],
50 valid_data = [os.path.join(DATA_PATH,'data/P07_valid_data.ft')], 50 valid_data = [os.path.join(DATA_PATH,'data/P07_valid_data.ft')],
51 valid_lbl = [os.path.join(DATA_PATH,'data/P07_valid_labels.ft')], 51 valid_lbl = [os.path.join(DATA_PATH,'data/P07_valid_labels.ft')],
52 indtype=theano.config.floatX, inscale=255.) 52 indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
53 53
54 mnist = GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz')) 54 mnist = lambda maxsize=None: GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz'),
55 maxsize=maxsize)