Mercurial > ift6266
diff datasets/ftfile.py @ 163:4b28d7382dbf
Add initial implementation of datasets.
For the moment only nist_digits is defined.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Thu, 25 Feb 2010 18:40:01 -0500 |
parents | |
children | 954185d6002a |
line wrap: on
line diff
# --- /dev/null Thu Jan 01 00:00:00 1970 +0000
# +++ b/datasets/ftfile.py Thu Feb 25 18:40:01 2010 -0500
# @@ -0,0 +1,156 @@
from pylearn.io.filetensor import _read_header, _prod
import numpy
from dataset import DataSet
from dsetiter import DataIterator

try:
    # Python 2: the builtin zip() builds a full list, which would force the
    # whole dataset to be read up front.  izip is the lazy equivalent.
    from itertools import izip as _lazy_zip
except ImportError:
    # Python 3: the builtin zip() is already lazy.
    _lazy_zip = zip


class FTFile(object):
    r"""
    Sequential reader over the items stored in a single .ft file.

    Attributes set by the constructor:
      `file`    -- the underlying file object, opened in binary mode
      `magic_t` -- first value returned by `_read_header`; used as the numpy
                   dtype when reading
      `elsize`  -- second value returned by `_read_header`; used as the
                   per-element byte size when skipping
      `dim`     -- fourth value returned by `_read_header`; its first entry is
                   taken as the number of items in the file
      `size`    -- number of items that can still be read
    """

    def __init__(self, fname):
        r"""
        Opens `fname` and reads its header.

        Tests:
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
        """
        self.file = open(fname, 'rb')
        self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
        self.size = self.dim[0]

    def close(self):
        r"""
        Closes the underlying file object.
        """
        self.file.close()

    def skip(self, num):
        r"""
        Skips `num` items in the file.

        If `num` is greater than or equal to the number of remaining items,
        the remaining count is set to zero and the file position is left
        untouched (subsequent reads return empty arrays).

        Tests:
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
            >>> f.size
            58646
            >>> f.elsize
            4
            >>> f.file.tell()
            20
            >>> f.skip(1000)
            >>> f.file.tell()
            4020
            >>> f.size
            57646
        """
        if num >= self.size:
            self.size = 0
        else:
            self.size -= num
            f_start = self.file.tell()
            # One item occupies elsize * prod(dim[1:]) bytes.
            self.file.seek(f_start + (self.elsize * _prod(self.dim[1:]) * num))

    def read(self, num):
        r"""
        Reads `num` elements from the file and return the result as a
        numpy matrix.  Last read is truncated.

        Tests:
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
            >>> f.read(1)
            array([6], dtype=int32)
            >>> f.read(10)
            array([7, 4, 7, 5, 6, 4, 8, 0, 9, 6], dtype=int32)
            >>> f.skip(58630)
            >>> f.read(10)
            array([9, 2, 4, 2, 8], dtype=int32)
            >>> f.read(10)
            array([], dtype=int32)
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
            >>> f.read(1)
            array([[0, 0, 0, ..., 0, 0, 0]], dtype=uint8)
        """
        # Never read more than what is left.
        if num > self.size:
            num = self.size
        # dim is reused as the shape of the returned array, so its first
        # entry is overwritten with the number of items in this read.
        self.dim[0] = num
        self.size -= num
        return numpy.fromfile(self.file, dtype=self.magic_t,
                              count=_prod(self.dim)).reshape(self.dim)


class FTSource(object):
    def __init__(self, file, skip=0, size=None):
        r"""
        Create a data source from a possible subset of a .ft file.

        Parameters:
            `file` (string) -- the filename
            `skip` (int, optional) -- amount of examples to skip from the start of the file
            `size` (int, optional) -- truncates number of examples read (after skipping)

        Tests:
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
        """
        self.file = file
        self.skip = skip
        self.size = size

    def open(self):
        r"""
        Returns an FTFile that corresponds to this dataset.

        The returned file has already skipped `self.skip` items and its
        remaining size is capped at `self.size` when that is set.

        Tests:
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
           >>> f = s.open()
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
           >>> len(s.open().read(2))
           1
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
           >>> s.open().size
           1000
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
           >>> s.open().size
           1
        """
        f = FTFile(self.file)
        if self.skip != 0:
            f.skip(self.skip)
        # Only shrink the readable size; never grow it past what the file has.
        if self.size is not None and self.size < f.size:
            f.size = self.size
        return f


class FTData(object):
    r"""
    This is a list of FTSources.

    `inputs` holds one FTSource per entry of `datafiles` and `outputs` one
    per entry of `labelfiles`; the same `skip` and `size` are applied to
    every source.
    """
    def __init__(self, datafiles, labelfiles, skip=0, size=None):
        self.inputs = [FTSource(f, skip, size) for f in datafiles]
        self.outputs = [FTSource(f, skip, size) for f in labelfiles]

    def open_inputs(self):
        r"""
        Returns a list of freshly opened FTFiles, one per input source.
        """
        return [f.open() for f in self.inputs]

    def open_outputs(self):
        r"""
        Returns a list of freshly opened FTFiles, one per output source.
        """
        return [f.open() for f in self.outputs]


def _count_examples(fnames):
    r"""
    Returns the total number of items in the .ft files named by `fnames`.

    Each file is opened only long enough to read its header and is closed
    again so that no file handle is left dangling.
    """
    total = 0
    for fname in fnames:
        f = FTFile(fname)
        try:
            total += f.size
        finally:
            f.close()
    return total


class FTDataSet(DataSet):
    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None):
        r"""
        Defines a DataSet from a bunch of files.

        Parameters:
           `train_data` -- list of train data files
           `train_lbl` -- list of train label files (same length as `train_data`)
           `test_data`, `test_lbl` -- same thing as train, but for
                                      test.  The number of files
                                      can differ from train.
           `valid_data`, `valid_lbl` -- same thing again for validation.
                                        (optional)

        If `valid_data` and `valid_lbl` are not supplied then a sample
        approximately equal in size to the test set is taken from the train
        set.
        """
        if valid_data is None:
            total_valid_size = _count_examples(test_data)
            # Floor division: the result is used as an item count and as
            # part of a seek offset, so it must stay an integer.
            valid_size = total_valid_size // len(train_data)
            # The first `valid_size` items of every train file become the
            # validation set; training starts right after them.
            self._train = FTData(train_data, train_lbl, skip=valid_size)
            self._valid = FTData(train_data, train_lbl, size=valid_size)
        else:
            self._train = FTData(train_data, train_lbl)
            self._valid = FTData(valid_data, valid_lbl)
        self._test = FTData(test_data, test_lbl)

    def _return_it(self, batchsize, bufsize, ftdata):
        r"""
        Returns an iterator over (inputs, outputs) pairs for `ftdata`.

        The pairing is lazy, so the input and output DataIterators are
        consumed in lockstep as the caller iterates rather than all at once.
        """
        return _lazy_zip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
                         DataIterator(ftdata.open_outputs(), batchsize, bufsize))