Mercurial > ift6266
diff datasets/ftfile.py @ 262:716c99f4eb3a
merge
author | Xavier Glorot <glorotxa@iro.umontreal.ca> |
---|---|
date | Wed, 17 Mar 2010 16:41:51 -0400 |
parents | 966272e7f14b |
children | a92ec9939e4f |
line wrap: on
line diff
--- a/datasets/ftfile.py Wed Mar 17 16:41:16 2010 -0400 +++ b/datasets/ftfile.py Wed Mar 17 16:41:51 2010 -0400 @@ -89,57 +89,58 @@ return res class FTSource(object): - def __init__(self, file, skip=0, size=None, dtype=None, scale=1): + def __init__(self, file, skip=0, size=None, maxsize=None, + dtype=None, scale=1): r""" Create a data source from a possible subset of a .ft file. Parameters: - `file` (string) -- the filename - `skip` (int, optional) -- amount of examples to skip from - the start of the file. If - negative, skips filesize - skip. - `size` (int, optional) -- truncates number of examples - read (after skipping). If - negative truncates to - filesize - size - (also after skipping). - `dtype` (dtype, optional) -- convert the data to this - dtype after reading. - `scale` (number, optional) -- scale (that is divide) the - data by this number (after - dtype conversion, if any). + `file` -- (string) the filename + `skip` -- (int, optional) amount of examples to skip from + the start of the file. If negative, skips + filesize - skip. + `size` -- (int, optional) truncates number of examples + read (after skipping). If negative truncates to + filesize - size (also after skipping). + `maxsize` -- (int, optional) the maximum size of the file + `dtype` -- (dtype, optional) convert the data to this + dtype after reading. + `scale` -- (number, optional) scale (that is divide) the + data by this number (after dtype conversion, if + any). Tests: - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) """ self.file = file self.skip = skip self.size = size self.dtype = dtype self.scale = scale + self.maxsize = maxsize def open(self): r""" Returns an FTFile that corresponds to this dataset. Tests: - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') - >>> f = s.open() - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1) - >>> len(s.open().read(2)) - 1 - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646) - >>> s.open().size - 1000 - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) - >>> s.open().size - 1 - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) - >>> s.open().size - 58636 + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') + >>> f = s.open() + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1) + >>> len(s.open().read(2)) + 1 + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646) + >>> s.open().size + 1000 + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) + >>> s.open().size + 1 + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) + >>> s.open().size + 58636 """ f = FTFile(self.file, scale=self.scale, dtype=self.dtype) if self.skip != 0: @@ -147,19 +148,25 @@ if self.size is not None and self.size < f.size: if self.size < 0: f.size += self.size + if f.size < 0: + f.size = 0 else: f.size = self.size + if self.maxsize is not None and f.size > self.maxsize: + f.size = self.maxsize return f class FTData(object): r""" This is a list of FTSources. """ - def __init__(self, datafiles, labelfiles, skip=0, size=None, + def __init__(self, datafiles, labelfiles, skip=0, size=None, maxsize=None, inscale=1, indtype=None, outscale=1, outdtype=None): - self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype) + if maxsize is not None: + maxsize /= len(datafiles) + self.inputs = [FTSource(f, skip, size, maxsize, scale=inscale, dtype=indtype) for f in datafiles] - self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype) + self.outputs = [FTSource(f, skip, size, maxsize, scale=outscale, dtype=outdtype) for f in labelfiles] def open_inputs(self): @@ -170,7 +177,9 @@ class FTDataSet(DataSet): - def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1): + def __init__(self, train_data, train_lbl, test_data, test_lbl, + valid_data=None, valid_lbl=None, indtype=None, outdtype=None, + inscale=1, outscale=1, maxsize=None): r""" Defines a DataSet from a bunch of files. @@ -184,6 +193,7 @@ (optional) `indtype`, `outdtype`, -- see FTSource.__init__() `inscale`, `outscale` (optional) + `maxsize` -- maximum size of the set returned If `valid_data` and `valid_labels` are not supplied then a sample @@ -191,21 +201,26 @@ set. """ if valid_data is None: - total_valid_size = sum(FTFile(td).size for td in test_data) + total_valid_size = min(sum(FTFile(td).size for td in test_data), maxsize) valid_size = total_valid_size/len(train_data) self._train = FTData(train_data, train_lbl, size=-valid_size, - inscale=inscale, outscale=outscale, indtype=indtype, - outdtype=outdtype) + inscale=inscale, outscale=outscale, + indtype=indtype, outdtype=outdtype, + maxsize=maxsize) self._valid = FTData(train_data, train_lbl, skip=-valid_size, - inscale=inscale, outscale=outscale, indtype=indtype, - outdtype=outdtype) + inscale=inscale, outscale=outscale, + indtype=indtype, outdtype=outdtype, + maxsize=maxsize) else: - self._train = FTData(train_data, train_lbl,inscale=inscale, - outscale=outscale, indtype=indtype, outdtype=outdtype) - self._valid = FTData(valid_data, valid_lbl,inscale=inscale, - outscale=outscale, indtype=indtype, outdtype=outdtype) - self._test = FTData(test_data, test_lbl,inscale=inscale, - outscale=outscale, indtype=indtype, outdtype=outdtype) + self._train = FTData(train_data, train_lbl, maxsize=maxsize, + inscale=inscale, outscale=outscale, + indtype=indtype, outdtype=outdtype) + self._valid = FTData(valid_data, valid_lbl, maxsize=maxsize, + inscale=inscale, outscale=outscale, + indtype=indtype, outdtype=outdtype) + self._test = FTData(test_data, test_lbl, maxsize=maxsize, + inscale=inscale, outscale=outscale, + indtype=indtype, outdtype=outdtype) def _return_it(self, batchsize, bufsize, ftdata): return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),