Mercurial > ift6266
diff datasets/ftfile.py @ 180:76bc047df5ee
Add dtype conversion and rescaling to the read path.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Sat, 27 Feb 2010 16:50:16 -0500 |
parents | 938bd350dbf0 |
children | f0f47b045cbf |
line wrap: on
line diff
--- a/datasets/ftfile.py Sat Feb 27 16:07:09 2010 -0500 +++ b/datasets/ftfile.py Sat Feb 27 16:50:16 2010 -0500 @@ -5,7 +5,7 @@ from itertools import izip, imap class FTFile(object): - def __init__(self, fname): + def __init__(self, fname, scale=1, dtype=None): r""" Tests: >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') @@ -13,6 +13,8 @@ self.file = open(fname, 'rb') self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False) self.size = self.dim[0] + self.scale = scale + self.dtype = dtype def skip(self, num): r""" @@ -79,10 +81,15 @@ num = self.size self.dim[0] = num self.size -= num - return numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim) + res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim) + if self.dtype is not None: + res = res.astype(self.dtype) + if self.scale != 1: + res /= self.scale + return res class FTSource(object): - def __init__(self, file, skip=0, size=None): + def __init__(self, file, skip=0, size=None, dtype=None, scale=1): r""" Create a data source from a possible subset of a .ft file. @@ -96,7 +103,12 @@ negative truncates to filesize - size (also after skipping). - + `dtype` (dtype, optional) -- convert the data to this + dtype after reading. + `scale` (number, optional) -- scale (that is divide) the + data by this number (after + dtype conversion, if any). + Tests: >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) @@ -106,6 +118,8 @@ self.file = file self.skip = skip self.size = size + self.dtype = dtype + self.scale = scale def open(self): r""" @@ -127,7 +141,7 @@ >>> s.open().size 58636 """ - f = FTFile(self.file) + f = FTFile(self.file, scale=self.scale, dtype=self.dtype) if self.skip != 0: f.skip(self.skip) if self.size is not None and self.size < f.size: @@ -141,9 +155,12 @@ r""" This is a list of FTSources. """ - def __init__(self, datafiles, labelfiles, skip=0, size=None): - self.inputs = [FTSource(f, skip, size) for f in datafiles] - self.outputs = [FTSource(f, skip, size) for f in labelfiles] + def __init__(self, datafiles, labelfiles, skip=0, size=None, + inscale=1, indtype=None, outscale=1, outdtype=None): + self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype) + for f in datafiles] + self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype) + for f in labelfiles] def open_inputs(self): return [f.open() for f in self.inputs] @@ -153,7 +170,7 @@ class FTDataSet(DataSet): - def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None): + def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1): r""" Defines a DataSet from a bunch of files. @@ -165,6 +182,9 @@ can differ from train. `valid_data`, `valid_labels` -- same thing again for validation. (optional) + `indtype`, `outdtype`, -- see FTSource.__init__() + `inscale`, `outscale` (optional) + If `valid_data` and `valid_labels` are not supplied then a sample approximately equal in size to the test set is taken from the train