diff datasets/defs.py @ 231:6f4e3719a3cc

Added the possibility to get the paths from environment variables (NIST_PATH, DATA_PATH) + cleaned up the way we build the paths (os.path.join instead of string concatenation)
author Dumitru Erhan <dumitru.erhan@gmail.com>
date Sat, 13 Mar 2010 15:44:50 -0500
parents 4cfd0eb438af
children 966272e7f14b
line wrap: on
line diff
--- a/datasets/defs.py	Fri Mar 12 10:47:36 2010 -0500
+++ b/datasets/defs.py	Sat Mar 13 15:44:50 2010 -0500
@@ -4,48 +4,51 @@
 from ftfile import FTDataSet
 from gzpklfile import GzpklDataSet
 import theano
-
-NIST_PATH = '/data/lisa/data/nist/by_class/'
-DATA_PATH = '/data/lisa/data/ift6266h10/'
+import os
 
-nist_digits = FTDataSet(train_data = [NIST_PATH+'digits/digits_train_data.ft'],
-                        train_lbl = [NIST_PATH+'digits/digits_train_labels.ft'],
-                        test_data = [NIST_PATH+'digits/digits_test_data.ft'],
-                        test_lbl = [NIST_PATH+'digits/digits_test_labels.ft'],
+# If the NIST_PATH / DATA_PATH environment variables are set, take the paths
+# from them; otherwise fall back on the defaults given here.
+NIST_PATH = os.getenv('NIST_PATH','/data/lisa/data/nist/by_class/')
+DATA_PATH = os.getenv('DATA_PATH','/data/lisa/data/ift6266h10/')
+
+nist_digits = FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')],
+                        train_lbl = [os.path.join(NIST_PATH,'digits/digits_train_labels.ft')],
+                        test_data = [os.path.join(NIST_PATH,'digits/digits_test_data.ft')],
+                        test_lbl = [os.path.join(NIST_PATH,'digits/digits_test_labels.ft')],
                         indtype=theano.config.floatX, inscale=255.)
-nist_lower = FTDataSet(train_data = [NIST_PATH+'lower/lower_train_data.ft'],
-                        train_lbl = [NIST_PATH+'lower/lower_train_labels.ft'],
-                        test_data = [NIST_PATH+'lower/lower_test_data.ft'],
-                        test_lbl = [NIST_PATH+'lower/lower_test_labels.ft'],
+nist_lower = FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')],
+                        train_lbl = [os.path.join(NIST_PATH,'lower/lower_train_labels.ft')],
+                        test_data = [os.path.join(NIST_PATH,'lower/lower_test_data.ft')],
+                        test_lbl = [os.path.join(NIST_PATH,'lower/lower_test_labels.ft')],
                         indtype=theano.config.floatX, inscale=255.)
-nist_upper = FTDataSet(train_data = [NIST_PATH+'upper/upper_train_data.ft'],
-                        train_lbl = [NIST_PATH+'upper/upper_train_labels.ft'],
-                        test_data = [NIST_PATH+'upper/upper_test_data.ft'],
-                        test_lbl = [NIST_PATH+'upper/upper_test_labels.ft'],
+nist_upper = FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')],
+                        train_lbl = [os.path.join(NIST_PATH,'upper/upper_train_labels.ft')],
+                        test_data = [os.path.join(NIST_PATH,'upper/upper_test_data.ft')],
+                        test_lbl = [os.path.join(NIST_PATH,'upper/upper_test_labels.ft')],
                         indtype=theano.config.floatX, inscale=255.)
 
-nist_all = FTDataSet(train_data = [DATA_PATH+'train_data.ft'],
-                     train_lbl = [DATA_PATH+'train_labels.ft'],
-                     test_data = [DATA_PATH+'test_data.ft'],
-                     test_lbl = [DATA_PATH+'test_labels.ft'],
-                     valid_data = [DATA_PATH+'valid_data.ft'],
-                     valid_lbl = [DATA_PATH+'valid_labels.ft'],
+nist_all = FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')],
+                     train_lbl = [os.path.join(DATA_PATH,'train_labels.ft')],
+                     test_data = [os.path.join(DATA_PATH,'test_data.ft')],
+                     test_lbl = [os.path.join(DATA_PATH,'test_labels.ft')],
+                     valid_data = [os.path.join(DATA_PATH,'valid_data.ft')],
+                     valid_lbl = [os.path.join(DATA_PATH,'valid_labels.ft')],
                      indtype=theano.config.floatX, inscale=255.)
 
-ocr = FTDataSet(train_data = [DATA_PATH+'ocr_train_data.ft'],
-                train_lbl = [DATA_PATH+'ocr_train_labels.ft'],
-                test_data = [DATA_PATH+'ocr_test_data.ft'],
-                test_lbl = [DATA_PATH+'ocr_test_labels.ft'],
-                valid_data = [DATA_PATH+'ocr_valid_data.ft'],
-                valid_lbl = [DATA_PATH+'ocr_valid_labels.ft'],
+ocr = FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')],
+                train_lbl = [os.path.join(DATA_PATH,'ocr_train_labels.ft')],
+                test_data = [os.path.join(DATA_PATH,'ocr_test_data.ft')],
+                test_lbl = [os.path.join(DATA_PATH,'ocr_test_labels.ft')],
+                valid_data = [os.path.join(DATA_PATH,'ocr_valid_data.ft')],
+                valid_lbl = [os.path.join(DATA_PATH,'ocr_valid_labels.ft')],
                 indtype=theano.config.floatX, inscale=255.)
 
-nist_P07 = FTDataSet(train_data = [DATA_PATH+'data/P07_train'+str(i)+'_data.ft' for i in range(100)],
-                     train_lbl = [DATA_PATH+'data/P07_train'+str(i)+'_labels.ft' for i in range(100)],
-                     test_data = [DATA_PATH+'data/P07_test_data.ft'],
-                     test_lbl = [DATA_PATH+'data/P07_test_labels.ft'],
-                     valid_data = [DATA_PATH+'data/P07_valid_data.ft'],
-                     valid_lbl = [DATA_PATH+'data/P07_valid_labels.ft'],
+nist_P07 = FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)],
+                     train_lbl = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_labels.ft') for i in range(100)],
+                     test_data = [os.path.join(DATA_PATH,'data/P07_test_data.ft')],
+                     test_lbl = [os.path.join(DATA_PATH,'data/P07_test_labels.ft')],
+                     valid_data = [os.path.join(DATA_PATH,'data/P07_valid_data.ft')],
+                     valid_lbl = [os.path.join(DATA_PATH,'data/P07_valid_labels.ft')],
                      indtype=theano.config.floatX, inscale=255.)
 
-mnist = GzpklDataSet(DATA_PATH+'mnist.pkl.gz')
+mnist = GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz'))