diff datasets/ftfile.py @ 262:716c99f4eb3a

merge
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Wed, 17 Mar 2010 16:41:51 -0400
parents 966272e7f14b
children a92ec9939e4f
line wrap: on
line diff
--- a/datasets/ftfile.py	Wed Mar 17 16:41:16 2010 -0400
+++ b/datasets/ftfile.py	Wed Mar 17 16:41:51 2010 -0400
@@ -89,57 +89,58 @@
         return res
 
 class FTSource(object):
-    def __init__(self, file, skip=0, size=None, dtype=None, scale=1):
+    def __init__(self, file, skip=0, size=None, maxsize=None, 
+                 dtype=None, scale=1):
         r"""
         Create a data source from a possible subset of a .ft file.
 
         Parameters:
-            `file` (string) -- the filename
-            `skip` (int, optional) -- amount of examples to skip from
-                                      the start of the file.  If
-                                      negative, skips filesize - skip.
-            `size` (int, optional) -- truncates number of examples
-                                      read (after skipping).  If
-                                      negative truncates to 
-                                      filesize - size 
-                                      (also after skipping).
-            `dtype` (dtype, optional) -- convert the data to this
-                                         dtype after reading.
-            `scale` (number, optional) -- scale (that is divide) the
-                                          data by this number (after
-                                          dtype conversion, if any).
+            `file` -- (string) the filename
+            `skip` -- (int, optional) amount of examples to skip from
+                      the start of the file.  If negative, skips
+                      filesize - skip.
+            `size` -- (int, optional) truncates number of examples
+                      read (after skipping).  If negative truncates to
+                      filesize - size (also after skipping).
+            `maxsize` -- (int, optional) the maximum size of the file
+            `dtype` -- (dtype, optional) convert the data to this
+                       dtype after reading.
+            `scale` -- (number, optional) scale (that is divide) the
+                       data by this number (after dtype conversion, if
+                       any).
 
         Tests:
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
         """
         self.file = file
         self.skip = skip
         self.size = size
         self.dtype = dtype
         self.scale = scale
+        self.maxsize = maxsize
     
     def open(self):
         r"""
         Returns an FTFile that corresponds to this dataset.
         
         Tests:
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
-           >>> f = s.open()
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
-           >>> len(s.open().read(2))
-           1
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
-           >>> s.open().size
-           1000
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
-           >>> s.open().size
-           1
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
-           >>> s.open().size
-           58636
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+        >>> f = s.open()
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
+        >>> len(s.open().read(2))
+        1
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
+        >>> s.open().size
+        1000
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
+        >>> s.open().size
+        1
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
+        >>> s.open().size
+        58636
         """
         f = FTFile(self.file, scale=self.scale, dtype=self.dtype)
         if self.skip != 0:
@@ -147,19 +148,25 @@
         if self.size is not None and self.size < f.size:
             if self.size < 0:
                 f.size += self.size
+                if f.size < 0:
+                    f.size = 0
             else:
                 f.size = self.size
+        if self.maxsize is not None and f.size > self.maxsize:
+            f.size = self.maxsize
         return f
 
 class FTData(object):
     r"""
     This is a list of FTSources.
     """
-    def __init__(self, datafiles, labelfiles, skip=0, size=None,
+    def __init__(self, datafiles, labelfiles, skip=0, size=None, maxsize=None,
                  inscale=1, indtype=None, outscale=1, outdtype=None):
-        self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype)
+        if maxsize is not None:
+            maxsize /= len(datafiles)
+        self.inputs = [FTSource(f, skip, size, maxsize, scale=inscale, dtype=indtype)
                        for f in  datafiles]
-        self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype)
+        self.outputs = [FTSource(f, skip, size, maxsize, scale=outscale, dtype=outdtype)
                         for f in labelfiles]
 
     def open_inputs(self):
@@ -170,7 +177,9 @@
     
 
 class FTDataSet(DataSet):
-    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1):
+    def __init__(self, train_data, train_lbl, test_data, test_lbl, 
+                 valid_data=None, valid_lbl=None, indtype=None, outdtype=None,
+                 inscale=1, outscale=1, maxsize=None):
         r"""
         Defines a DataSet from a bunch of files.
         
@@ -184,6 +193,7 @@
                                            (optional)
            `indtype`, `outdtype`,  -- see FTSource.__init__()
            `inscale`, `outscale`      (optional)
+           `maxsize` -- maximum size of the set returned
                                                              
 
         If `valid_data` and `valid_labels` are not supplied then a sample
@@ -191,21 +201,26 @@
         set.
         """
         if valid_data is None:
-            total_valid_size = sum(FTFile(td).size for td in test_data)
+            total_valid_size = min(sum(FTFile(td).size for td in test_data), maxsize)
             valid_size = total_valid_size/len(train_data)
             self._train = FTData(train_data, train_lbl, size=-valid_size,
-                    inscale=inscale, outscale=outscale, indtype=indtype,
-                    outdtype=outdtype)
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype,
+                                 maxsize=maxsize)
             self._valid = FTData(train_data, train_lbl, skip=-valid_size,
-                    inscale=inscale, outscale=outscale, indtype=indtype, 
-                    outdtype=outdtype)
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype,
+                                 maxsize=maxsize)
         else:
-            self._train = FTData(train_data, train_lbl,inscale=inscale,
-                    outscale=outscale, indtype=indtype, outdtype=outdtype)
-            self._valid = FTData(valid_data, valid_lbl,inscale=inscale,
-                    outscale=outscale, indtype=indtype, outdtype=outdtype)
-        self._test = FTData(test_data, test_lbl,inscale=inscale,
-                outscale=outscale, indtype=indtype, outdtype=outdtype)
+            self._train = FTData(train_data, train_lbl, maxsize=maxsize,
+                                 inscale=inscale, outscale=outscale, 
+                                 indtype=indtype, outdtype=outdtype)
+            self._valid = FTData(valid_data, valid_lbl, maxsize=maxsize,
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype)
+        self._test = FTData(test_data, test_lbl, maxsize=maxsize,
+                            inscale=inscale, outscale=outscale,
+                            indtype=indtype, outdtype=outdtype)
 
     def _return_it(self, batchsize, bufsize, ftdata):
         return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),