diff datasets/defs.py @ 246:2024368a8d3d

merge
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Tue, 16 Mar 2010 12:14:10 -0400
parents 6f4e3719a3cc
children 966272e7f14b
line wrap: on
line diff
--- a/datasets/defs.py	Tue Mar 16 12:13:49 2010 -0400
+++ b/datasets/defs.py	Tue Mar 16 12:14:10 2010 -0400
@@ -1,38 +1,54 @@
-__all__ = ['nist_digits', 'nist_lower', 'nist_upper', 'nist_all', 'ocr']
+__all__ = ['nist_digits', 'nist_lower', 'nist_upper', 'nist_all', 'ocr', 
+           'nist_P07', 'mnist']
 
 from ftfile import FTDataSet
+from gzpklfile import GzpklDataSet
 import theano
-
-NIST_PATH = '/data/lisa/data/nist/by_class/'
-DATA_PATH = '/data/lisa/data/ift6266h10/'
+import os
 
-nist_digits = FTDataSet(train_data = [NIST_PATH+'digits/digits_train_data.ft'],
-                        train_lbl = [NIST_PATH+'digits/digits_train_labels.ft'],
-                        test_data = [NIST_PATH+'digits/digits_test_data.ft'],
-                        test_lbl = [NIST_PATH+'digits/digits_test_labels.ft'],
+# If the NIST_PATH / DATA_PATH environment variables are set, take the paths
+# from them; otherwise fall back on the default locations given below.
+NIST_PATH = os.getenv('NIST_PATH','/data/lisa/data/nist/by_class/')
+DATA_PATH = os.getenv('DATA_PATH','/data/lisa/data/ift6266h10/')
+
+nist_digits = FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')],
+                        train_lbl = [os.path.join(NIST_PATH,'digits/digits_train_labels.ft')],
+                        test_data = [os.path.join(NIST_PATH,'digits/digits_test_data.ft')],
+                        test_lbl = [os.path.join(NIST_PATH,'digits/digits_test_labels.ft')],
                         indtype=theano.config.floatX, inscale=255.)
-nist_lower = FTDataSet(train_data = [NIST_PATH+'lower/lower_train_data.ft'],
-                        train_lbl = [NIST_PATH+'lower/lower_train_labels.ft'],
-                        test_data = [NIST_PATH+'lower/lower_test_data.ft'],
-                        test_lbl = [NIST_PATH+'lower/lower_test_labels.ft'],
+nist_lower = FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')],
+                        train_lbl = [os.path.join(NIST_PATH,'lower/lower_train_labels.ft')],
+                        test_data = [os.path.join(NIST_PATH,'lower/lower_test_data.ft')],
+                        test_lbl = [os.path.join(NIST_PATH,'lower/lower_test_labels.ft')],
                         indtype=theano.config.floatX, inscale=255.)
-nist_upper = FTDataSet(train_data = [NIST_PATH+'upper/upper_train_data.ft'],
-                        train_lbl = [NIST_PATH+'upper/upper_train_labels.ft'],
-                        test_data = [NIST_PATH+'upper/upper_test_data.ft'],
-                        test_lbl = [NIST_PATH+'upper/upper_test_labels.ft'],
+nist_upper = FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')],
+                        train_lbl = [os.path.join(NIST_PATH,'upper/upper_train_labels.ft')],
+                        test_data = [os.path.join(NIST_PATH,'upper/upper_test_data.ft')],
+                        test_lbl = [os.path.join(NIST_PATH,'upper/upper_test_labels.ft')],
                         indtype=theano.config.floatX, inscale=255.)
 
-nist_all = FTDataSet(train_data = [DATA_PATH+'train_data.ft'],
-                     train_lbl = [DATA_PATH+'train_labels.ft'],
-                     test_data = [DATA_PATH+'test_data.ft'],
-                     test_lbl = [DATA_PATH+'test_labels.ft'],
-                     valid_data = [DATA_PATH+'valid_data.ft'],
-                     valid_lbl = [DATA_PATH+'valid_labels.ft'],
+nist_all = FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')],
+                     train_lbl = [os.path.join(DATA_PATH,'train_labels.ft')],
+                     test_data = [os.path.join(DATA_PATH,'test_data.ft')],
+                     test_lbl = [os.path.join(DATA_PATH,'test_labels.ft')],
+                     valid_data = [os.path.join(DATA_PATH,'valid_data.ft')],
+                     valid_lbl = [os.path.join(DATA_PATH,'valid_labels.ft')],
                      indtype=theano.config.floatX, inscale=255.)
 
-ocr = FTDataSet(train_data = [DATA_PATH+'ocr_train_data.ft'],
-                train_lbl = [DATA_PATH+'ocr_train_labels.ft'],
-                test_data = [DATA_PATH+'ocr_test_data.ft'],
-                test_lbl = [DATA_PATH+'ocr_test_labels.ft'],
-                valid_data = [DATA_PATH+'ocr_valid_data.ft'],
-                valid_lbl = [DATA_PATH+'ocr_valid_labels.ft'])
+ocr = FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')],
+                train_lbl = [os.path.join(DATA_PATH,'ocr_train_labels.ft')],
+                test_data = [os.path.join(DATA_PATH,'ocr_test_data.ft')],
+                test_lbl = [os.path.join(DATA_PATH,'ocr_test_labels.ft')],
+                valid_data = [os.path.join(DATA_PATH,'ocr_valid_data.ft')],
+                valid_lbl = [os.path.join(DATA_PATH,'ocr_valid_labels.ft')],
+                indtype=theano.config.floatX, inscale=255.)
+
+nist_P07 = FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)],
+                     train_lbl = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_labels.ft') for i in range(100)],
+                     test_data = [os.path.join(DATA_PATH,'data/P07_test_data.ft')],
+                     test_lbl = [os.path.join(DATA_PATH,'data/P07_test_labels.ft')],
+                     valid_data = [os.path.join(DATA_PATH,'data/P07_valid_data.ft')],
+                     valid_lbl = [os.path.join(DATA_PATH,'data/P07_valid_labels.ft')],
+                     indtype=theano.config.floatX, inscale=255.)
+
+mnist = GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz'))