Mercurial > ift6266
comparison datasets/gzpklfile.py @ 222:4cfd0eb438af
Add mnist to datasets (and supporting code).
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Thu, 11 Mar 2010 14:41:31 -0500 |
parents | |
children | 966272e7f14b |
comparison
equal
deleted
inserted
replaced
217:de3aef84714a | 222:4cfd0eb438af |
---|---|
1 import gzip | |
2 try: | |
3 import cPickle as pickle | |
4 except ImportError: | |
5 import pickle | |
6 | |
7 from dataset import DataSet | |
8 from dsetiter import DataIterator | |
9 from itertools import izip | |
10 | |
class ArrayFile(object):
    """Minimal read-only, file-like cursor over a sliceable sequence.

    Only ``read`` is provided: each call returns the next slice of the
    wrapped sequence and moves an internal position forward.
    """

    def __init__(self, ary):
        """Wrap *ary* with the cursor placed on its first element.

        *ary* must support slicing, and its slices must support ``len()``.
        """
        self.ary = ary
        self.pos = 0

    def read(self, num=-1):
        """Return up to *num* items starting at the current position.

        A negative or ``None`` *num* returns everything that remains, as
        file objects do.  Fewer than *num* items (possibly an empty slice)
        are returned once the end of the sequence is reached.
        """
        if num is None or num < 0:
            res = self.ary[self.pos:]
        else:
            res = self.ary[self.pos:self.pos + num]
        # Advance by what was actually returned so that ``pos`` never runs
        # past the end of the data on short or empty reads.
        self.pos += len(res)
        return res
20 | |
class GzpklDataSet(DataSet):
    """Dataset read from a gzip-compressed pickle file.

    The file is only unpickled the first time an iterator is requested.
    The unpickled object is indexed as ``datas[id][0]`` and
    ``datas[id][1]``; presumably these are the inputs and targets of one
    split (verify against the file that is loaded).
    """

    def __init__(self, fname):
        """Remember the path of the pickle; nothing is read yet."""
        self._fname = fname
        # Indices into the unpickled object for each split.
        self._train, self._valid, self._test = 0, 1, 2

    def _load(self):
        """Unpickle the whole file into ``self.datas``."""
        # NOTE(review): pickle can execute arbitrary code while loading;
        # only point this at trusted files.
        handle = gzip.open(self._fname, 'rb')
        try:
            self.datas = pickle.load(handle)
        finally:
            handle.close()

    def _return_it(self, batchsz, bufsz, id):
        """Return an iterator pairing items from the two arrays of split *id*.

        Each of ``datas[id][0]`` and ``datas[id][1]`` is wrapped in an
        ``ArrayFile`` and handed to its own ``DataIterator`` built with
        *batchsz* and *bufsz*; the two iterators are zipped together.
        """
        if not hasattr(self, 'datas'):
            self._load()
        streams = [
            DataIterator([ArrayFile(self.datas[id][part])], batchsz, bufsz)
            for part in (0, 1)
        ]
        return izip(*streams)