# HG changeset patch # User fsavard # Date 1268342979 18000 # Node ID eb78a695ad7ade7b12a951968167738cc80922de # Parent 0515a8901c6af568e7a37c0fae75fb7077ef988d# Parent 8547b0cbe4ff477d8475dcafa80bb7a3ec32372f Merge diff -r 0515a8901c6a -r eb78a695ad7a datasets/defs.py --- a/datasets/defs.py Thu Mar 11 11:52:43 2010 -0500 +++ b/datasets/defs.py Thu Mar 11 16:29:39 2010 -0500 @@ -1,7 +1,8 @@ __all__ = ['nist_digits', 'nist_lower', 'nist_upper', 'nist_all', 'ocr', - 'nist_P07'] + 'nist_P07', 'mnist'] from ftfile import FTDataSet +from gzpklfile import GzpklDataSet import theano NIST_PATH = '/data/lisa/data/nist/by_class/' @@ -46,3 +47,5 @@ valid_data = [DATA_PATH+'data/P07_valid_data.ft'], valid_lbl = [DATA_PATH+'data/P07_valid_labels.ft'], indtype=theano.config.floatX, inscale=255.) + +mnist = GzpklDataSet(DATA_PATH+'mnist.pkl.gz') diff -r 0515a8901c6a -r eb78a695ad7a datasets/gzpklfile.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/datasets/gzpklfile.py Thu Mar 11 16:29:39 2010 -0500 @@ -0,0 +1,39 @@ +import gzip +try: + import cPickle as pickle +except ImportError: + import pickle + +from dataset import DataSet +from dsetiter import DataIterator +from itertools import izip + +class ArrayFile(object): + def __init__(self, ary): + self.ary = ary + self.pos = 0 + + def read(self, num): + res = self.ary[self.pos:self.pos+num] + self.pos += num + return res + +class GzpklDataSet(DataSet): + def __init__(self, fname): + self._fname = fname + self._train = 0 + self._valid = 1 + self._test = 2 + + def _load(self): + f = gzip.open(self._fname, 'rb') + try: + self.datas = pickle.load(f) + finally: + f.close() + + def _return_it(self, batchsz, bufsz, id): + if not hasattr(self, 'datas'): + self._load() + return izip(DataIterator([ArrayFile(self.datas[id][0])], batchsz, bufsz), + DataIterator([ArrayFile(self.datas[id][1])], batchsz, bufsz))