Mercurial > ift6266
changeset 257:966272e7f14b
Make the datasets lazy-loading and add a maxsize parameter.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Tue, 16 Mar 2010 18:51:27 -0400 |
parents | 7e6fecabb656 |
children | c2fae7b96769 |
files | datasets/defs.py datasets/ftfile.py datasets/gzpklfile.py |
diffstat | 3 files changed, 81 insertions(+), 64 deletions(-) [+] |
line wrap: on
line diff
--- a/datasets/defs.py Tue Mar 16 14:46:25 2010 -0400 +++ b/datasets/defs.py Tue Mar 16 18:51:27 2010 -0400 @@ -11,44 +11,45 @@ NIST_PATH = os.getenv('NIST_PATH','/data/lisa/data/nist/by_class/') DATA_PATH = os.getenv('DATA_PATH','/data/lisa/data/ift6266h10/') -nist_digits = FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')], +nist_digits = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')], train_lbl = [os.path.join(NIST_PATH,'digits/digits_train_labels.ft')], test_data = [os.path.join(NIST_PATH,'digits/digits_test_data.ft')], test_lbl = [os.path.join(NIST_PATH,'digits/digits_test_labels.ft')], - indtype=theano.config.floatX, inscale=255.) -nist_lower = FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')], + indtype=theano.config.floatX, inscale=255., maxsize=maxsize) +nist_lower = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')], train_lbl = [os.path.join(NIST_PATH,'lower/lower_train_labels.ft')], test_data = [os.path.join(NIST_PATH,'lower/lower_test_data.ft')], test_lbl = [os.path.join(NIST_PATH,'lower/lower_test_labels.ft')], - indtype=theano.config.floatX, inscale=255.) -nist_upper = FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')], + indtype=theano.config.floatX, inscale=255., maxsize=maxsize) +nist_upper = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')], train_lbl = [os.path.join(NIST_PATH,'upper/upper_train_labels.ft')], test_data = [os.path.join(NIST_PATH,'upper/upper_test_data.ft')], test_lbl = [os.path.join(NIST_PATH,'upper/upper_test_labels.ft')], - indtype=theano.config.floatX, inscale=255.) + indtype=theano.config.floatX, inscale=255., maxsize=maxsize) -nist_all = FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')], +nist_all = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')], train_lbl = [os.path.join(DATA_PATH,'train_labels.ft')], test_data = [os.path.join(DATA_PATH,'test_data.ft')], test_lbl = [os.path.join(DATA_PATH,'test_labels.ft')], valid_data = [os.path.join(DATA_PATH,'valid_data.ft')], valid_lbl = [os.path.join(DATA_PATH,'valid_labels.ft')], - indtype=theano.config.floatX, inscale=255.) + indtype=theano.config.floatX, inscale=255., maxsize=maxsize) -ocr = FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')], +ocr = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')], train_lbl = [os.path.join(DATA_PATH,'ocr_train_labels.ft')], test_data = [os.path.join(DATA_PATH,'ocr_test_data.ft')], test_lbl = [os.path.join(DATA_PATH,'ocr_test_labels.ft')], valid_data = [os.path.join(DATA_PATH,'ocr_valid_data.ft')], valid_lbl = [os.path.join(DATA_PATH,'ocr_valid_labels.ft')], - indtype=theano.config.floatX, inscale=255.) + indtype=theano.config.floatX, inscale=255., maxsize=maxsize) -nist_P07 = FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)], +nist_P07 = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)], train_lbl = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_labels.ft') for i in range(100)], test_data = [os.path.join(DATA_PATH,'data/P07_test_data.ft')], test_lbl = [os.path.join(DATA_PATH,'data/P07_test_labels.ft')], valid_data = [os.path.join(DATA_PATH,'data/P07_valid_data.ft')], valid_lbl = [os.path.join(DATA_PATH,'data/P07_valid_labels.ft')], - indtype=theano.config.floatX, inscale=255.) + indtype=theano.config.floatX, inscale=255., maxsize=maxsize) -mnist = GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz')) +mnist = lambda maxsize=None: GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz'), + maxsize=maxsize)
--- a/datasets/ftfile.py Tue Mar 16 14:46:25 2010 -0400 +++ b/datasets/ftfile.py Tue Mar 16 18:51:27 2010 -0400 @@ -89,57 +89,58 @@ return res class FTSource(object): - def __init__(self, file, skip=0, size=None, dtype=None, scale=1): + def __init__(self, file, skip=0, size=None, maxsize=None, + dtype=None, scale=1): r""" Create a data source from a possible subset of a .ft file. Parameters: - `file` (string) -- the filename - `skip` (int, optional) -- amount of examples to skip from - the start of the file. If - negative, skips filesize - skip. - `size` (int, optional) -- truncates number of examples - read (after skipping). If - negative truncates to - filesize - size - (also after skipping). - `dtype` (dtype, optional) -- convert the data to this - dtype after reading. - `scale` (number, optional) -- scale (that is divide) the - data by this number (after - dtype conversion, if any). + `file` -- (string) the filename + `skip` -- (int, optional) amount of examples to skip from + the start of the file. If negative, skips + filesize - skip. + `size` -- (int, optional) truncates number of examples + read (after skipping). If negative truncates to + filesize - size (also after skipping). + `maxsize` -- (int, optional) the maximum size of the file + `dtype` -- (dtype, optional) convert the data to this + dtype after reading. + `scale` -- (number, optional) scale (that is divide) the + data by this number (after dtype conversion, if + any). Tests: - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) """ self.file = file self.skip = skip self.size = size self.dtype = dtype self.scale = scale + self.maxsize = maxsize def open(self): r""" Returns an FTFile that corresponds to this dataset. Tests: - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') - >>> f = s.open() - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1) - >>> len(s.open().read(2)) - 1 - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646) - >>> s.open().size - 1000 - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) - >>> s.open().size - 1 - >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) - >>> s.open().size - 58636 + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') + >>> f = s.open() + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1) + >>> len(s.open().read(2)) + 1 + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646) + >>> s.open().size + 1000 + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) + >>> s.open().size + 1 + >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) + >>> s.open().size + 58636 """ f = FTFile(self.file, scale=self.scale, dtype=self.dtype) if self.skip != 0: @@ -147,19 +148,25 @@ if self.size is not None and self.size < f.size: if self.size < 0: f.size += self.size + if f.size < 0: + f.size = 0 else: f.size = self.size + if self.maxsize is not None and f.size > self.maxsize: + f.size = self.maxsize return f class FTData(object): r""" This is a list of FTSources. """ - def __init__(self, datafiles, labelfiles, skip=0, size=None, + def __init__(self, datafiles, labelfiles, skip=0, size=None, maxsize=None, inscale=1, indtype=None, outscale=1, outdtype=None): - self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype) + if maxsize is not None: + maxsize /= len(datafiles) + self.inputs = [FTSource(f, skip, size, maxsize, scale=inscale, dtype=indtype) for f in datafiles] - self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype) + self.outputs = [FTSource(f, skip, size, maxsize, scale=outscale, dtype=outdtype) for f in labelfiles] def open_inputs(self): @@ -170,7 +177,9 @@ class FTDataSet(DataSet): - def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1): + def __init__(self, train_data, train_lbl, test_data, test_lbl, + valid_data=None, valid_lbl=None, indtype=None, outdtype=None, + inscale=1, outscale=1, maxsize=None): r""" Defines a DataSet from a bunch of files. @@ -184,6 +193,7 @@ (optional) `indtype`, `outdtype`, -- see FTSource.__init__() `inscale`, `outscale` (optional) + `maxsize` -- maximum size of the set returned If `valid_data` and `valid_labels` are not supplied then a sample @@ -191,21 +201,26 @@ set. """ if valid_data is None: - total_valid_size = sum(FTFile(td).size for td in test_data) + total_valid_size = min(sum(FTFile(td).size for td in test_data), maxsize) valid_size = total_valid_size/len(train_data) self._train = FTData(train_data, train_lbl, size=-valid_size, - inscale=inscale, outscale=outscale, indtype=indtype, - outdtype=outdtype) + inscale=inscale, outscale=outscale, + indtype=indtype, outdtype=outdtype, + maxsize=maxsize) self._valid = FTData(train_data, train_lbl, skip=-valid_size, - inscale=inscale, outscale=outscale, indtype=indtype, - outdtype=outdtype) + inscale=inscale, outscale=outscale, + indtype=indtype, outdtype=outdtype, + maxsize=maxsize) else: - self._train = FTData(train_data, train_lbl,inscale=inscale, - outscale=outscale, indtype=indtype, outdtype=outdtype) - self._valid = FTData(valid_data, valid_lbl,inscale=inscale, - outscale=outscale, indtype=indtype, outdtype=outdtype) - self._test = FTData(test_data, test_lbl,inscale=inscale, - outscale=outscale, indtype=indtype, outdtype=outdtype) + self._train = FTData(train_data, train_lbl, maxsize=maxsize, + inscale=inscale, outscale=outscale, + indtype=indtype, outdtype=outdtype) + self._valid = FTData(valid_data, valid_lbl, maxsize=maxsize, + inscale=inscale, outscale=outscale, + indtype=indtype, outdtype=outdtype) + self._test = FTData(test_data, test_lbl, maxsize=maxsize, + inscale=inscale, outscale=outscale, + indtype=indtype, outdtype=outdtype) def _return_it(self, batchsize, bufsize, ftdata): return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
--- a/datasets/gzpklfile.py Tue Mar 16 14:46:25 2010 -0400 +++ b/datasets/gzpklfile.py Tue Mar 16 18:51:27 2010 -0400 @@ -19,8 +19,9 @@ return res class GzpklDataSet(DataSet): - def __init__(self, fname): + def __init__(self, fname, maxsize): self._fname = fname + self.maxsize = maxsize self._train = 0 self._valid = 1 self._test = 2 @@ -35,5 +36,5 @@ def _return_it(self, batchsz, bufsz, id): if not hasattr(self, 'datas'): self._load() - return izip(DataIterator([ArrayFile(self.datas[id][0])], batchsz, bufsz), - DataIterator([ArrayFile(self.datas[id][1])], batchsz, bufsz)) + return izip(DataIterator([ArrayFile(self.datas[id][0][:maxsize])], batchsz, bufsz), + DataIterator([ArrayFile(self.datas[id][1][:maxsize])], batchsz, bufsz))