changeset 349:22efb4968054

added pnist support, will check in code for data set iterator later
author xaviermuller
date Mon, 19 Apr 2010 10:12:17 -0400
parents 45156cbf6722
children 625c0c3fcbdb
files baseline/mlp/mlp_nist.py datasets/defs.py
diffstat 2 files changed, 12 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/baseline/mlp/mlp_nist.py	Mon Apr 19 08:17:45 2010 -0400
+++ b/baseline/mlp/mlp_nist.py	Mon Apr 19 10:12:17 2010 -0400
@@ -268,6 +268,8 @@
     	dataset=datasets.nist_all()
     elif data_set==1:
         dataset=datasets.nist_P07()
+    elif data_set==2:
+        dataset=datasets.PNIST07()
     
     
     
--- a/datasets/defs.py	Mon Apr 19 08:17:45 2010 -0400
+++ b/datasets/defs.py	Mon Apr 19 10:12:17 2010 -0400
@@ -1,5 +1,5 @@
 __all__ = ['nist_digits', 'nist_lower', 'nist_upper', 'nist_all', 'ocr', 
-           'nist_P07', 'mnist']
+           'nist_P07', 'PNIST07', 'mnist']
 
 from ftfile import FTDataSet
 from gzpklfile import GzpklDataSet
@@ -52,6 +52,15 @@
                      valid_data = [os.path.join(DATA_PATH,'data/P07_valid_data.ft')],
                      valid_lbl = [os.path.join(DATA_PATH,'data/P07_valid_labels.ft')],
                      indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
+		     
+#Added PNIST07
+PNIST07 = lambda maxsize=None, min_file=0, max_file=100: FTDataSet(train_data = [os.path.join(DATA_PATH,'data/PNIST07_train'+str(i)+'_data.ft') for i in range(min_file, max_file)],
+                     train_lbl = [os.path.join(DATA_PATH,'data/PNIST07_train'+str(i)+'_labels.ft') for i in range(min_file, max_file)],
+                     test_data = [os.path.join(DATA_PATH,'data/PNIST07_test_data.ft')],
+                     test_lbl = [os.path.join(DATA_PATH,'data/PNIST07_test_labels.ft')],
+                     valid_data = [os.path.join(DATA_PATH,'data/PNIST07_valid_data.ft')],
+                     valid_lbl = [os.path.join(DATA_PATH,'data/PNIST07_valid_labels.ft')],
+                     indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
 
 mnist = lambda maxsize=None: GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz'),
                                           maxsize=maxsize)