changeset 222:4cfd0eb438af

Add mnist to datasets (and supporting code).
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 11 Mar 2010 14:41:31 -0500
parents de3aef84714a
children 8547b0cbe4ff
files datasets/defs.py datasets/gzpklfile.py
diffstat 2 files changed, 43 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/datasets/defs.py	Wed Mar 10 17:08:50 2010 -0500
+++ b/datasets/defs.py	Thu Mar 11 14:41:31 2010 -0500
@@ -1,7 +1,8 @@
 __all__ = ['nist_digits', 'nist_lower', 'nist_upper', 'nist_all', 'ocr', 
-           'nist_P07']
+           'nist_P07', 'mnist']
 
 from ftfile import FTDataSet
+from gzpklfile import GzpklDataSet
 import theano
 
 NIST_PATH = '/data/lisa/data/nist/by_class/'
@@ -46,3 +47,5 @@
                      valid_data = [DATA_PATH+'data/P07_valid_data.ft'],
                      valid_lbl = [DATA_PATH+'data/P07_valid_labels.ft'],
                      indtype=theano.config.floatX, inscale=255.)
+
+mnist = GzpklDataSet(DATA_PATH+'mnist.pkl.gz')  # only stores the path here; the file is unpickled on the first _return_it call
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/datasets/gzpklfile.py	Thu Mar 11 14:41:31 2010 -0500
@@ -0,0 +1,39 @@
+import gzip
+try:
+    import cPickle as pickle
+except ImportError:
+    import pickle
+
+from dataset import DataSet
+from dsetiter import DataIterator
+from itertools import izip
+
+class ArrayFile(object):  # gives a sliceable object a read(num) method with a moving position
+    def __init__(self, ary):  # ary: any object that supports slicing (ary[a:b])
+        self.ary = ary
+        self.pos = 0  # index of the next unread element; moved forward by read()
+
+    def read(self, num):  # return ary[pos:pos+num], then move pos forward by num
+        res = self.ary[self.pos:self.pos+num]
+        self.pos += num  # advanced by num unconditionally, whatever the slice actually returned
+        return res
+
+class GzpklDataSet(DataSet):  # DataSet whose contents are unpickled from a gzip-compressed file
+    def __init__(self, fname):  # fname: path handed to gzip.open by _load(); nothing is read here
+        self._fname = fname
+        self._train = 0  # _train/_valid/_test: presumably the `id` values the DataSet base passes to _return_it -- confirm in dataset.py
+        self._valid = 1
+        self._test = 2
+
+    def _load(self):  # unpickle the whole file into self.datas
+        f = gzip.open(self._fname, 'rb')
+        try:
+            self.datas = pickle.load(f)  # NOTE(review): unpickling can execute arbitrary code; only point this at trusted files
+        finally:
+            f.close()  # runs even when pickle.load raises
+
+    def _return_it(self, batchsz, bufsz, id):  # izip of two DataIterators over self.datas[id][0] and self.datas[id][1]
+        if not hasattr(self, 'datas'):  # lazy load: the file is read only on the first call
+            self._load()
+        return izip(DataIterator([ArrayFile(self.datas[id][0])], batchsz, bufsz),  # datas[id][0]: presumably the inputs -- confirm
+                    DataIterator([ArrayFile(self.datas[id][1])], batchsz, bufsz))  # datas[id][1]: presumably the labels -- confirm