changeset 257:966272e7f14b

Make the dataset definitions lazy-loading factories and add a maxsize parameter to cap the number of examples.
author Arnaud Bergeron <abergeron@gmail.com>
date Tue, 16 Mar 2010 18:51:27 -0400
parents 7e6fecabb656
children c2fae7b96769
files datasets/defs.py datasets/ftfile.py datasets/gzpklfile.py
diffstat 3 files changed, 83 insertions(+), 64 deletions(-)
--- a/datasets/defs.py	Tue Mar 16 14:46:25 2010 -0400
+++ b/datasets/defs.py	Tue Mar 16 18:51:27 2010 -0400
@@ -11,44 +11,45 @@
 NIST_PATH = os.getenv('NIST_PATH','/data/lisa/data/nist/by_class/')
 DATA_PATH = os.getenv('DATA_PATH','/data/lisa/data/ift6266h10/')
 
-nist_digits = FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')],
+nist_digits = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')],
                         train_lbl = [os.path.join(NIST_PATH,'digits/digits_train_labels.ft')],
                         test_data = [os.path.join(NIST_PATH,'digits/digits_test_data.ft')],
                         test_lbl = [os.path.join(NIST_PATH,'digits/digits_test_labels.ft')],
-                        indtype=theano.config.floatX, inscale=255.)
-nist_lower = FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')],
+                        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
+nist_lower = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')],
                         train_lbl = [os.path.join(NIST_PATH,'lower/lower_train_labels.ft')],
                         test_data = [os.path.join(NIST_PATH,'lower/lower_test_data.ft')],
                         test_lbl = [os.path.join(NIST_PATH,'lower/lower_test_labels.ft')],
-                        indtype=theano.config.floatX, inscale=255.)
-nist_upper = FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')],
+                        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
+nist_upper = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')],
                         train_lbl = [os.path.join(NIST_PATH,'upper/upper_train_labels.ft')],
                         test_data = [os.path.join(NIST_PATH,'upper/upper_test_data.ft')],
                         test_lbl = [os.path.join(NIST_PATH,'upper/upper_test_labels.ft')],
-                        indtype=theano.config.floatX, inscale=255.)
+                        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
 
-nist_all = FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')],
+nist_all = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')],
                      train_lbl = [os.path.join(DATA_PATH,'train_labels.ft')],
                      test_data = [os.path.join(DATA_PATH,'test_data.ft')],
                      test_lbl = [os.path.join(DATA_PATH,'test_labels.ft')],
                      valid_data = [os.path.join(DATA_PATH,'valid_data.ft')],
                      valid_lbl = [os.path.join(DATA_PATH,'valid_labels.ft')],
-                     indtype=theano.config.floatX, inscale=255.)
+                     indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
 
-ocr = FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')],
+ocr = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')],
                 train_lbl = [os.path.join(DATA_PATH,'ocr_train_labels.ft')],
                 test_data = [os.path.join(DATA_PATH,'ocr_test_data.ft')],
                 test_lbl = [os.path.join(DATA_PATH,'ocr_test_labels.ft')],
                 valid_data = [os.path.join(DATA_PATH,'ocr_valid_data.ft')],
                 valid_lbl = [os.path.join(DATA_PATH,'ocr_valid_labels.ft')],
-                indtype=theano.config.floatX, inscale=255.)
+                indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
 
-nist_P07 = FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)],
+nist_P07 = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)],
                      train_lbl = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_labels.ft') for i in range(100)],
                      test_data = [os.path.join(DATA_PATH,'data/P07_test_data.ft')],
                      test_lbl = [os.path.join(DATA_PATH,'data/P07_test_labels.ft')],
                      valid_data = [os.path.join(DATA_PATH,'data/P07_valid_data.ft')],
                      valid_lbl = [os.path.join(DATA_PATH,'data/P07_valid_labels.ft')],
-                     indtype=theano.config.floatX, inscale=255.)
+                     indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
 
-mnist = GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz'))
+mnist = lambda maxsize=None: GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz'),
+                                          maxsize=maxsize)
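
With this change the names defined in defs.py are factories rather than prebuilt objects: each is a lambda that builds its dataset only when called, optionally capped by maxsize. A minimal usage sketch (the `datasets.defs` import path is an assumption; adjust to the actual package layout):

    # Hypothetical import path; the module shown above is datasets/defs.py.
    from datasets.defs import nist_digits, mnist

    full = nist_digits()               # dataset is constructed only here, at call time
    small = nist_digits(maxsize=1000)  # each underlying source capped at 1000 examples
    tiny = mnist(maxsize=500)          # GzpklDataSet slices each split to 500 rows
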
--- a/datasets/ftfile.py	Tue Mar 16 14:46:25 2010 -0400
+++ b/datasets/ftfile.py	Tue Mar 16 18:51:27 2010 -0400
@@ -89,57 +89,58 @@
         return res
 
 class FTSource(object):
-    def __init__(self, file, skip=0, size=None, dtype=None, scale=1):
+    def __init__(self, file, skip=0, size=None, maxsize=None, 
+                 dtype=None, scale=1):
         r"""
         Create a data source from a possible subset of a .ft file.
 
         Parameters:
-            `file` (string) -- the filename
-            `skip` (int, optional) -- amount of examples to skip from
-                                      the start of the file.  If
-                                      negative, skips filesize - skip.
-            `size` (int, optional) -- truncates number of examples
-                                      read (after skipping).  If
-                                      negative truncates to 
-                                      filesize - size 
-                                      (also after skipping).
-            `dtype` (dtype, optional) -- convert the data to this
-                                         dtype after reading.
-            `scale` (number, optional) -- scale (that is divide) the
-                                          data by this number (after
-                                          dtype conversion, if any).
+            `file` -- (string) the filename
+            `skip` -- (int, optional) amount of examples to skip from
+                      the start of the file.  If negative, skips
+                      filesize - skip.
+            `size` -- (int, optional) truncates number of examples
+                      read (after skipping).  If negative truncates to
+                      filesize - size (also after skipping).
+            `maxsize` -- (int, optional) maximum number of examples to read
+            `dtype` -- (dtype, optional) convert the data to this
+                       dtype after reading.
+            `scale` -- (number, optional) scale (that is divide) the
+                       data by this number (after dtype conversion, if
+                       any).
 
         Tests:
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
         """
         self.file = file
         self.skip = skip
         self.size = size
         self.dtype = dtype
         self.scale = scale
+        self.maxsize = maxsize
     
     def open(self):
         r"""
         Returns an FTFile that corresponds to this dataset.
         
         Tests:
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
-           >>> f = s.open()
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
-           >>> len(s.open().read(2))
-           1
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
-           >>> s.open().size
-           1000
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
-           >>> s.open().size
-           1
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
-           >>> s.open().size
-           58636
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+        >>> f = s.open()
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
+        >>> len(s.open().read(2))
+        1
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
+        >>> s.open().size
+        1000
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
+        >>> s.open().size
+        1
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
+        >>> s.open().size
+        58636
         """
         f = FTFile(self.file, scale=self.scale, dtype=self.dtype)
         if self.skip != 0:
@@ -147,19 +148,25 @@
         if self.size is not None and self.size < f.size:
             if self.size < 0:
                 f.size += self.size
+                if f.size < 0:
+                    f.size = 0
             else:
                 f.size = self.size
+        if self.maxsize is not None and f.size > self.maxsize:
+            f.size = self.maxsize
         return f
 
 class FTData(object):
     r"""
     This is a list of FTSources.
     """
-    def __init__(self, datafiles, labelfiles, skip=0, size=None,
+    def __init__(self, datafiles, labelfiles, skip=0, size=None, maxsize=None,
                  inscale=1, indtype=None, outscale=1, outdtype=None):
-        self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype)
+        if maxsize is not None:
+            maxsize /= len(datafiles)
+        self.inputs = [FTSource(f, skip, size, maxsize, scale=inscale, dtype=indtype)
                        for f in  datafiles]
-        self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype)
+        self.outputs = [FTSource(f, skip, size, maxsize, scale=outscale, dtype=outdtype)
                         for f in labelfiles]
 
     def open_inputs(self):
@@ -170,7 +177,9 @@
     
 
 class FTDataSet(DataSet):
-    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1):
+    def __init__(self, train_data, train_lbl, test_data, test_lbl, 
+                 valid_data=None, valid_lbl=None, indtype=None, outdtype=None,
+                 inscale=1, outscale=1, maxsize=None):
         r"""
         Defines a DataSet from a bunch of files.
         
@@ -184,6 +193,7 @@
                                            (optional)
            `indtype`, `outdtype`,  -- see FTSource.__init__()
            `inscale`, `outscale`      (optional)
+           `maxsize` -- maximum number of examples per split (optional)
                                                              
 
         If `valid_data` and `valid_labels` are not supplied then a sample
@@ -191,21 +201,28 @@
         set.
         """
         if valid_data is None:
-            total_valid_size = sum(FTFile(td).size for td in test_data)
+            total_valid_size = sum(FTFile(td).size for td in test_data)
+            if maxsize is not None:
+                total_valid_size = min(total_valid_size, maxsize)
             valid_size = total_valid_size/len(train_data)
             self._train = FTData(train_data, train_lbl, size=-valid_size,
-                    inscale=inscale, outscale=outscale, indtype=indtype,
-                    outdtype=outdtype)
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype,
+                                 maxsize=maxsize)
             self._valid = FTData(train_data, train_lbl, skip=-valid_size,
-                    inscale=inscale, outscale=outscale, indtype=indtype, 
-                    outdtype=outdtype)
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype,
+                                 maxsize=maxsize)
         else:
-            self._train = FTData(train_data, train_lbl,inscale=inscale,
-                    outscale=outscale, indtype=indtype, outdtype=outdtype)
-            self._valid = FTData(valid_data, valid_lbl,inscale=inscale,
-                    outscale=outscale, indtype=indtype, outdtype=outdtype)
-        self._test = FTData(test_data, test_lbl,inscale=inscale,
-                outscale=outscale, indtype=indtype, outdtype=outdtype)
+            self._train = FTData(train_data, train_lbl, maxsize=maxsize,
+                                 inscale=inscale, outscale=outscale, 
+                                 indtype=indtype, outdtype=outdtype)
+            self._valid = FTData(valid_data, valid_lbl, maxsize=maxsize,
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype)
+        self._test = FTData(test_data, test_lbl, maxsize=maxsize,
+                            inscale=inscale, outscale=outscale,
+                            indtype=indtype, outdtype=outdtype)
 
     def _return_it(self, batchsize, bufsize, ftdata):
         return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
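
FTData splits the maxsize budget evenly across its data files (integer division), and FTSource.open() then caps each file's size after skip/size are applied. A small sketch of that arithmetic; the helper and the file sizes are illustrative, not part of the patch:

    def capped_sizes(file_sizes, maxsize):
        # Divide the budget across the files, then cap the number of
        # examples read from each one (as FTData/FTSource do above).
        per_file = maxsize // len(file_sizes)
        return [min(size, per_file) for size in file_sizes]

    # nist_P07 has 100 training files; with maxsize=50000 each file
    # contributes at most 500 examples to the train split.
    sizes = capped_sizes([80000] * 100, 50000)
    assert sizes == [500] * 100
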
--- a/datasets/gzpklfile.py	Tue Mar 16 14:46:25 2010 -0400
+++ b/datasets/gzpklfile.py	Tue Mar 16 18:51:27 2010 -0400
@@ -19,8 +19,9 @@
         return res
 
 class GzpklDataSet(DataSet):
-    def __init__(self, fname):
+    def __init__(self, fname, maxsize=None):
         self._fname = fname
+        self.maxsize = maxsize
         self._train = 0
         self._valid = 1
         self._test = 2
@@ -35,5 +36,5 @@
     def _return_it(self, batchsz, bufsz, id):
         if not hasattr(self, 'datas'):
             self._load()
-        return izip(DataIterator([ArrayFile(self.datas[id][0])], batchsz, bufsz),
-                    DataIterator([ArrayFile(self.datas[id][1])], batchsz, bufsz))
+        return izip(DataIterator([ArrayFile(self.datas[id][0][:self.maxsize])], batchsz, bufsz),
+                    DataIterator([ArrayFile(self.datas[id][1][:self.maxsize])], batchsz, bufsz))
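
For the gzipped-pickle (MNIST) set, maxsize is applied as a plain slice of the unpickled arrays in _return_it(). A quick sketch of the slicing behaviour it relies on; the array here is a stand-in for the real split data:

    import numpy

    data = numpy.arange(60000).reshape(-1, 1)   # stand-in for self.datas[id][0]
    assert len(data[:1000]) == 1000             # maxsize=1000 keeps the first 1000 rows
    assert len(data[:None]) == 60000            # maxsize=None keeps the whole split
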