Mercurial > ift6266
changeset 163:4b28d7382dbf
Add inital implementation of datasets.
For the moment only nist_digits is defined.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Thu, 25 Feb 2010 18:40:01 -0500 |
parents | 050c7ff6b449 |
children | e3de934a98b6 |
files | datasets/__init__.py datasets/dataset.py datasets/dsetiter.py datasets/ftfile.py datasets/nist.py |
diffstat | 5 files changed, 382 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/datasets/__init__.py Thu Feb 25 18:40:01 2010 -0500 @@ -0,0 +1,1 @@ +from nist import *
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/datasets/dataset.py Thu Feb 25 18:40:01 2010 -0500 @@ -0,0 +1,46 @@ +from dsetiter import DataIterator + +class DataSet(object): + def test(self, batchsize, bufsize=None): + r""" + Returns an iterator over the test examples. + + Parameters + batchsize (int) -- the size of the minibatches, 0 means + return the whole set at once. + bufsize (int, optional) -- the size of the in-memory buffer, + 0 to disable. + """ + return self._return_it(batchsize, bufsize, self._test) + + def train(self, batchsize, bufsize=None): + r""" + Returns an iterator over the training examples. + + Parameters + batchsize (int) -- the size of the minibatches, 0 means + return the whole set at once. + bufsize (int, optional) -- the size of the in-memory buffer, + 0 to disable. + """ + return self._return_it(batchsize, bufsize, self._train) + + def valid(self, batchsize, bufsize=None): + r""" + Returns an iterator over the validation examples. + + Parameters + batchsize (int) -- the size of the minibatches, 0 means + return the whole set at once. + bufsize (int, optional) -- the size of the in-memory buffer, + 0 to disable. + """ + return self._return_it(batchsize, bufsize, self._valid) + + def _return_it(batchsize, bufsize, data): + r""" + Must return an iterator over the specified dataset (`data`). + + Implement this in subclassses. + """ + raise NotImplemented
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/datasets/dsetiter.py Thu Feb 25 18:40:01 2010 -0500 @@ -0,0 +1,156 @@ +import numpy + +class DummyFile(object): + def __init__(self, size): + self.size = size + + def read(self, num): + if num > self.size: + num = self.size + self.size -= num + return numpy.zeros((num, 3, 2)) + +class DataIterator(object): + + def __init__(self, files, batchsize, bufsize=None): + r""" + Makes an iterator which will read examples from `files` + and return them in `batchsize` lots. + + Parameters: + files -- list of numpy readers + batchsize -- (int) the size of returned batches + bufsize -- (int, default=None) internal read buffer size. + + Tests: + >>> d = DataIterator([DummyFile(930)], 10, 100) + >>> d.batchsize + 10 + >>> d.bufsize + 100 + >>> d = DataIterator([DummyFile(1)], 10) + >>> d.batchsize + 10 + >>> d.bufsize + 10000 + >>> d = DataIterator([DummyFile(1)], 99) + >>> d.batchsize + 99 + >>> d.bufsize + 9999 + >>> d = DataIterator([DummyFile(1)], 10, 121) + >>> d.batchsize + 10 + >>> d.bufsize + 120 + >>> d = DataIterator([DummyFile(1)], 10, 1) + >>> d.batchsize + 10 + >>> d.bufsize + 10 + >>> d = DataIterator([DummyFile(1)], 2000) + >>> d.batchsize + 2000 + >>> d.bufsize + 20000 + >>> d = DataIterator([DummyFile(1)], 2000, 31254) + >>> d.batchsize + 2000 + >>> d.bufsize + 30000 + >>> d = DataIterator([DummyFile(1)], 2000, 10) + >>> d.batchsize + 2000 + >>> d.bufsize + 2000 + """ + self.batchsize = batchsize + if bufsize is None: + self.bufsize = max(10*batchsize, 10000) + else: + self.bufsize = bufsize + self.bufsize -= self.bufsize % self.batchsize + if self.bufsize < self.batchsize: + self.bufsize = self.batchsize + self.files = iter(files) + self.curfile = self.files.next() + self.empty = False + self._fill_buf() + + def _fill_buf(self): + r""" + Fill the internal buffer. + + Will fill across files in case the current one runs out. + + Test: + >>> d = DataIterator([DummyFile(20)], 10, 10) + >>> d._fill_buf() + >>> d.curpos + 0 + >>> len(d.buffer) + 10 + >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10) + >>> d._fill_buf() + >>> len(d.buffer) + 10 + >>> d._fill_buf() + Traceback (most recent call last): + ... + StopIteration + >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10) + >>> d._fill_buf() + >>> len(d.buffer) + 9 + >>> d._fill_buf() + Traceback (most recent call last): + ... + StopIteration + """ + if self.empty: + raise StopIteration + self.buffer = self.curfile.read(self.bufsize) + + while len(self.buffer) < self.bufsize: + try: + self.curfile = self.files.next() + except StopIteration: + self.empty = True + if len(self.buffer) == 0: + raise StopIteration + self.curpos = 0 + return + tmpbuf = self.curfile.read(self.bufsize - len(self.buffer)) + self.buffer = numpy.row_stack((self.buffer, tmpbuf)) + self.curpos = 0 + + def __next__(self): + r""" + Returns the next portion of the dataset. + + Test: + >>> d = DataIterator([DummyFile(20)], 10, 20) + >>> len(d.next()) + 10 + >>> len(d.next()) + 10 + >>> d.next() + Traceback (most recent call last): + ... + StopIteration + >>> d.next() + Traceback (most recent call last): + ... + StopIteration + + """ + if self.curpos >= self.bufsize: + self._fill_buf() + res = self.buffer[self.curpos:self.curpos+self.batchsize] + self.curpos += self.batchsize + return res + + next = __next__ + + def __iter__(self): + return self
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/datasets/ftfile.py Thu Feb 25 18:40:01 2010 -0500 @@ -0,0 +1,156 @@ +from pylearn.io.filetensor import _read_header, _prod +import numpy +from dataset import DataSet +from dsetiter import DataIterator + +class FTFile(object): + def __init__(self, fname): + r""" + Tests: + >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') + """ + self.file = open(fname, 'rb') + self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False) + self.size = self.dim[0] + + def skip(self, num): + r""" + Skips `num` items in the file. + + Tests: + >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') + >>> f.size + 58646 + >>> f.elsize + 4 + >>> f.file.tell() + 20 + >>> f.skip(1000) + >>> f.file.tell() + 4020 + >>> f.size + 57646 + """ + if num >= self.size: + self.size = 0 + else: + self.size -= num + f_start = self.file.tell() + self.file.seek(f_start + (self.elsize * _prod(self.dim[1:]) * num)) + + def read(self, num): + r""" + Reads `num` elements from the file and return the result as a + numpy matrix. Last read is truncated. + + Tests: + >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') + >>> f.read(1) + array([6], dtype=int32) + >>> f.read(10) + array([7, 4, 7, 5, 6, 4, 8, 0, 9, 6], dtype=int32) + >>> f.skip(58630) + >>> f.read(10) + array([9, 2, 4, 2, 8], dtype=int32) + >>> f.read(10) + array([], dtype=int32) + >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') + >>> f.read(1) + array([[0, 0, 0, ..., 0, 0, 0]], dtype=uint8) + """ + if num > self.size: + num = self.size + self.dim[0] = num + self.size -= num + return numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim) + +class FTSource(object): + def __init__(self, file, skip=0, size=None): + r""" + Create a data source from a possible subset of a .ft file. + + Parameters: + `file` (string) -- the filename + `skip` (int, optional) -- amount of examples to skip from the start of the file + `size` (int, optional) -- truncates number of examples read (after skipping) + + Tests: + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) + """ + self.file = file + self.skip = skip + self.size = size + + def open(self): + r""" + Returns an FTFile that corresponds to this dataset. + + Tests: + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') + >>> f = s.open() + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1) + >>> len(s.open().read(2)) + 1 + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646) + >>> s.open().size + 1000 + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) + >>> s.open().size + 1 + """ + f = FTFile(self.file) + if self.skip != 0: + f.skip(self.skip) + if self.size is not None and self.size < f.size: + f.size = self.size + return f + +class FTData(object): + r""" + This is a list of FTSources. + """ + def __init__(self, datafiles, labelfiles, skip=0, size=None): + self.inputs = [FTSource(f, skip, size) for f in datafiles] + self.outputs = [FTSource(f, skip, size) for f in labelfiles] + + def open_inputs(self): + return [f.open() for f in self.inputs] + + def open_outputs(self): + return [f.open() for f in self.outputs] + + +class FTDataSet(DataSet): + def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None): + r""" + Defines a DataSet from a bunch of files. + + Parameters: + `train_data` -- list of train data files + `train_label` -- list of train label files (same length as `train_data`) + `test_data`, `test_labels` -- same thing as train, but for + test. The number of files + can differ from train. + `valid_data`, `valid_labels` -- same thing again for validation. + (optional) + + If `valid_data` and `valid_labels` are not supplied then a sample + approximately equal in size to the test set is taken from the train + set. + """ + if valid_data is None: + total_valid_size = sum(FTFile(td).size for td in test_data) + valid_size = total_valid_size/len(train_data) + self._train = FTData(train_data, train_lbl, skip=valid_size) + self._valid = FTData(train_data, train_lbl, size=valid_size) + else: + self._train = FTData(train_data, train_lbl) + self._valid = FTData(valid_data, valid_lbl) + self._test = FTData(test_data, test_lbl) + + def _return_it(self, batchsize, bufsize, ftdata): + return zip(DataIterator(ftdata.open_inputs(), batchsize, bufsize), + DataIterator(ftdata.open_outputs(), batchsize, bufsize))
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/datasets/nist.py Thu Feb 25 18:40:01 2010 -0500 @@ -0,0 +1,23 @@ +__all__ = ['nist_digits'] + +from ftfile import FTDataSet + +PATH = '/data/lisa/data/nist/by_class/' + +nist_digits = FTDataSet(train_data = [PATH+'digits/digits_train_data.ft'], + train_lbl = [PATH+'digits/digits_train_labels.ft'], + test_data = [PATH+'digits/digits_test_data.ft'], + test_lbl = [PATH+'digits/digits_test_labels.ft']) +nist_lower = FTDataSet(train_data = [PATH+'lower/lower_train_data.ft'], + train_lbl = [PATH+'lower/lower_train_labels.ft'], + test_data = [PATH+'lower/lower_test_data.ft'], + test_lbl = [PATH+'lower/lower_test_labels.ft']) +nist_upper = FTDataSet(train_data = [PATH+'upper/upper_train_data.ft'], + train_lbl = [PATH+'upper/upper_train_labels.ft'], + test_data = [PATH+'upper/upper_test_data.ft'], + test_lbl = [PATH+'upper/upper_test_labels.ft']) +nist_all = FTDataSet(train_data = [PATH+'all/all_train_data.ft'], + train_lbl = [PATH+'all/all_train_labels.ft'], + test_data = [PATH+'all/all_test_data.ft'], + test_lbl = [PATH+'all/all_test_labels.ft']) +