diff datasets/gzpklfile.py @ 222:4cfd0eb438af

Add mnist to datasets (and supporting code).
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 11 Mar 2010 14:41:31 -0500
parents
children 966272e7f14b
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/datasets/gzpklfile.py	Thu Mar 11 14:41:31 2010 -0500
@@ -0,0 +1,39 @@
+import gzip
+try:
+    import cPickle as pickle
+except ImportError:
+    import pickle
+
+from dataset import DataSet
+from dsetiter import DataIterator
+from itertools import izip
+
+class ArrayFile(object):
+    def __init__(self, ary):
+        self.ary = ary
+        self.pos = 0
+
+    def read(self, num):
+        res = self.ary[self.pos:self.pos+num]
+        self.pos += num
+        return res
+
+class GzpklDataSet(DataSet):
+    def __init__(self, fname):
+        self._fname = fname
+        self._train = 0
+        self._valid = 1
+        self._test = 2
+
+    def _load(self):
+        f = gzip.open(self._fname, 'rb')
+        try:
+            self.datas = pickle.load(f)
+        finally:
+            f.close()
+
+    def _return_it(self, batchsz, bufsz, id):
+        if not hasattr(self, 'datas'):
+            self._load()
+        return izip(DataIterator([ArrayFile(self.datas[id][0])], batchsz, bufsz),
+                    DataIterator([ArrayFile(self.datas[id][1])], batchsz, bufsz))