Mercurial > ift6266
view datasets/ftfile.py @ 163:4b28d7382dbf
Add initial implementation of datasets.
For the moment only nist_digits is defined.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Thu, 25 Feb 2010 18:40:01 -0500 |
parents | |
children | 954185d6002a |
line wrap: on
line source
try:
    # Python 2: the builtin zip() materializes a full list, which would read
    # the whole dataset up front.  Use the lazy variant instead.
    from itertools import izip as _lazy_zip
except ImportError:
    # Python 3: the builtin zip() is already lazy.
    _lazy_zip = zip

from pylearn.io.filetensor import _read_header, _prod
import numpy

from dataset import DataSet
from dsetiter import DataIterator


class FTFile(object):
    r"""
    Sequential reader over the examples stored in a .ft file.

    Attributes set by the constructor:
      `file`    -- the underlying binary file object
      `magic_t` -- passed as the numpy dtype when reading
      `elsize`  -- size in bytes of one element (used to compute offsets)
      `dim`     -- dimensions read from the header; `dim[0]` is overwritten
                   by every call to `read`
      `size`    -- number of examples still available to `read`/`skip`

    The object can be used as a context manager; the file is closed on exit.
    """

    def __init__(self, fname):
        r"""
        Opens `fname` and reads its header.

        Parameters:
          `fname` (string) -- the filename

        Tests:
        >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
        """
        self.file = open(fname, 'rb')
        try:
            self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
        except Exception:
            # Do not leak the handle when the header cannot be parsed.
            self.file.close()
            raise
        self.size = self.dim[0]

    def close(self):
        r"""
        Closes the underlying file.  Safe to call more than once.
        """
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never swallow exceptions raised inside the `with` body.
        return False

    def skip(self, num):
        r"""
        Skips `num` items in the file.

        If fewer than `num` items remain, only the remaining items are
        skipped, so the file position never moves past the last available
        item.

        Tests:
        >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
        >>> f.size
        58646
        >>> f.elsize
        4
        >>> f.file.tell()
        20
        >>> f.skip(1000)
        >>> f.file.tell()
        4020
        >>> f.size
        57646
        """
        # Clamp so that `size` and the file position stay in agreement.
        if num > self.size:
            num = self.size
        self.size -= num
        # Bytes per item = element size * number of elements in one item.
        # Plain ints keep the offset arithmetic out of fixed-width integers.
        item_bytes = int(self.elsize * _prod(self.dim[1:]))
        self.file.seek(self.file.tell() + item_bytes * int(num))

    def read(self, num):
        r"""
        Reads `num` elements from the file and return the result as a
        numpy matrix.  Last read is truncated.

        Tests:
        >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
        >>> f.read(1)
        array([6], dtype=int32)
        >>> f.read(10)
        array([7, 4, 7, 5, 6, 4, 8, 0, 9, 6], dtype=int32)
        >>> f.skip(58630)
        >>> f.read(10)
        array([9, 2, 4, 2, 8], dtype=int32)
        >>> f.read(10)
        array([], dtype=int32)
        >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
        >>> f.read(1)
        array([[0, 0, 0, ..., 0, 0, 0]], dtype=uint8)
        """
        if num > self.size:
            num = self.size
        # `dim` doubles as the shape of the returned array, so its first
        # entry is set to the number of items actually read.
        self.dim[0] = num
        self.size -= num
        return numpy.fromfile(self.file, dtype=self.magic_t,
                              count=_prod(self.dim)).reshape(self.dim)


class FTSource(object):
    def __init__(self, file, skip=0, size=None):
        r"""
        Create a data source from a possible subset of a .ft file.

        Parameters:
          `file` (string) -- the filename
          `skip` (int, optional) -- amount of examples to skip from the
                                    start of the file
          `size` (int, optional) -- truncates number of examples read
                                    (after skipping)

        Tests:
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
        """
        self.file = file
        self.skip = skip
        self.size = size

    def open(self):
        r"""
        Returns an FTFile that corresponds to this dataset.

        The returned file is already positioned after the skipped examples
        and its `size` is capped to this source's `size`, if one was given.

        Tests:
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
        >>> f = s.open()
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
        >>> len(s.open().read(2))
        1
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
        >>> s.open().size
        1000
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
        >>> s.open().size
        1
        """
        f = FTFile(self.file)
        if self.skip != 0:
            f.skip(self.skip)
        # Only ever shrink the window; never claim more than the file holds.
        if self.size is not None and self.size < f.size:
            f.size = self.size
        return f


class FTData(object):
    r"""
    This is a list of FTSources.

    `inputs` holds one FTSource per data file and `outputs` one per label
    file; the same `skip` and `size` are applied to every file.
    """

    def __init__(self, datafiles, labelfiles, skip=0, size=None):
        self.inputs = [FTSource(f, skip, size) for f in datafiles]
        self.outputs = [FTSource(f, skip, size) for f in labelfiles]

    def open_inputs(self):
        r"""
        Returns a freshly opened FTFile for every data file.
        """
        return [f.open() for f in self.inputs]

    def open_outputs(self):
        r"""
        Returns a freshly opened FTFile for every label file.
        """
        return [f.open() for f in self.outputs]


def _count_examples(fnames):
    r"""
    Returns the total number of examples declared in the headers of the
    .ft files `fnames`, closing each file once its header has been read.
    """
    total = 0
    for fname in fnames:
        f = FTFile(fname)
        try:
            total += f.size
        finally:
            f.close()
    return total


class FTDataSet(DataSet):
    def __init__(self, train_data, train_lbl, test_data, test_lbl,
                 valid_data=None, valid_lbl=None):
        r"""
        Defines a DataSet from a bunch of files.

        Parameters:
          `train_data` -- list of train data files
          `train_lbl` -- list of train label files (same length as
                         `train_data`)
          `test_data`, `test_lbl` -- same thing as train, but for test.
                                     The number of files can differ from
                                     train.
          `valid_data`, `valid_lbl` -- same thing again for validation.
                                       (optional)

        If `valid_data` and `valid_lbl` are not supplied then a sample
        approximately equal in size to the test set is taken from the
        train set: the first `valid_size` examples of every train file are
        used for validation and skipped for training.
        """
        if valid_data is None:
            total_valid_size = _count_examples(test_data)
            # Floor division: the per-file share is used as an example
            # count, so it has to stay an integer.
            valid_size = total_valid_size // len(train_data)
            self._train = FTData(train_data, train_lbl, skip=valid_size)
            self._valid = FTData(train_data, train_lbl, size=valid_size)
        else:
            self._train = FTData(train_data, train_lbl)
            self._valid = FTData(valid_data, valid_lbl)
        self._test = FTData(test_data, test_lbl)

    def _return_it(self, batchsize, bufsize, ftdata):
        r"""
        Returns an iterator over (inputs, outputs) pairs, one pair per
        step of the two DataIterators built over the files of `ftdata`.

        The pairing is lazy, so nothing is read until the result is
        iterated.
        """
        return _lazy_zip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
                         DataIterator(ftdata.open_outputs(), batchsize, bufsize))