diff datasets/ftfile.py @ 266:1e4e60ddadb1

Merge. Ah, et dans le dernier commit, j'avais oublié de mentionner que j'ai ajouté du code pour gérer l'isolation de différents clones pour rouler des expériences et modifier le code en même temps.
author fsavard
date Fri, 19 Mar 2010 10:56:16 -0400
parents 966272e7f14b
children a92ec9939e4f
line wrap: on
line diff
--- a/datasets/ftfile.py	Fri Mar 19 10:54:39 2010 -0400
+++ b/datasets/ftfile.py	Fri Mar 19 10:56:16 2010 -0400
@@ -89,57 +89,58 @@
         return res
 
 class FTSource(object):
-    def __init__(self, file, skip=0, size=None, dtype=None, scale=1):
+    def __init__(self, file, skip=0, size=None, maxsize=None, 
+                 dtype=None, scale=1):
         r"""
         Create a data source from a possible subset of a .ft file.
 
         Parameters:
-            `file` (string) -- the filename
-            `skip` (int, optional) -- amount of examples to skip from
-                                      the start of the file.  If
-                                      negative, skips filesize - skip.
-            `size` (int, optional) -- truncates number of examples
-                                      read (after skipping).  If
-                                      negative truncates to 
-                                      filesize - size 
-                                      (also after skipping).
-            `dtype` (dtype, optional) -- convert the data to this
-                                         dtype after reading.
-            `scale` (number, optional) -- scale (that is divide) the
-                                          data by this number (after
-                                          dtype conversion, if any).
+            `file` -- (string) the filename
+            `skip` -- (int, optional) amount of examples to skip from
+                      the start of the file.  If negative, skips
+                      filesize - skip.
+            `size` -- (int, optional) truncates number of examples
+                      read (after skipping).  If negative truncates to
+                      filesize - size (also after skipping).
+            `maxsize` -- (int, optional) the maximum size of the file
+            `dtype` -- (dtype, optional) convert the data to this
+                       dtype after reading.
+            `scale` -- (number, optional) scale (that is divide) the
+                       data by this number (after dtype conversion, if
+                       any).
 
         Tests:
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
         """
         self.file = file
         self.skip = skip
         self.size = size
         self.dtype = dtype
         self.scale = scale
+        self.maxsize = maxsize
     
     def open(self):
         r"""
         Returns an FTFile that corresponds to this dataset.
         
         Tests:
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
-           >>> f = s.open()
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
-           >>> len(s.open().read(2))
-           1
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
-           >>> s.open().size
-           1000
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
-           >>> s.open().size
-           1
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
-           >>> s.open().size
-           58636
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+        >>> f = s.open()
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
+        >>> len(s.open().read(2))
+        1
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
+        >>> s.open().size
+        1000
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
+        >>> s.open().size
+        1
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
+        >>> s.open().size
+        58636
         """
         f = FTFile(self.file, scale=self.scale, dtype=self.dtype)
         if self.skip != 0:
@@ -147,19 +148,25 @@
         if self.size is not None and self.size < f.size:
             if self.size < 0:
                 f.size += self.size
+                if f.size < 0:
+                    f.size = 0
             else:
                 f.size = self.size
+        if self.maxsize is not None and f.size > self.maxsize:
+            f.size = self.maxsize
         return f
 
 class FTData(object):
     r"""
     This is a list of FTSources.
     """
-    def __init__(self, datafiles, labelfiles, skip=0, size=None,
+    def __init__(self, datafiles, labelfiles, skip=0, size=None, maxsize=None,
                  inscale=1, indtype=None, outscale=1, outdtype=None):
-        self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype)
+        if maxsize is not None:
+            maxsize /= len(datafiles)
+        self.inputs = [FTSource(f, skip, size, maxsize, scale=inscale, dtype=indtype)
                        for f in  datafiles]
-        self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype)
+        self.outputs = [FTSource(f, skip, size, maxsize, scale=outscale, dtype=outdtype)
                         for f in labelfiles]
 
     def open_inputs(self):
@@ -170,7 +177,9 @@
     
 
 class FTDataSet(DataSet):
-    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1):
+    def __init__(self, train_data, train_lbl, test_data, test_lbl, 
+                 valid_data=None, valid_lbl=None, indtype=None, outdtype=None,
+                 inscale=1, outscale=1, maxsize=None):
         r"""
         Defines a DataSet from a bunch of files.
         
@@ -184,6 +193,7 @@
                                            (optional)
            `indtype`, `outdtype`,  -- see FTSource.__init__()
            `inscale`, `outscale`      (optional)
+           `maxsize` -- maximum size of the set returned
                                                              
 
         If `valid_data` and `valid_labels` are not supplied then a sample
@@ -191,21 +201,26 @@
         set.
         """
         if valid_data is None:
-            total_valid_size = sum(FTFile(td).size for td in test_data)
+            total_valid_size = min(sum(FTFile(td).size for td in test_data), maxsize)
             valid_size = total_valid_size/len(train_data)
             self._train = FTData(train_data, train_lbl, size=-valid_size,
-                    inscale=inscale, outscale=outscale, indtype=indtype,
-                    outdtype=outdtype)
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype,
+                                 maxsize=maxsize)
             self._valid = FTData(train_data, train_lbl, skip=-valid_size,
-                    inscale=inscale, outscale=outscale, indtype=indtype, 
-                    outdtype=outdtype)
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype,
+                                 maxsize=maxsize)
         else:
-            self._train = FTData(train_data, train_lbl,inscale=inscale,
-                    outscale=outscale, indtype=indtype, outdtype=outdtype)
-            self._valid = FTData(valid_data, valid_lbl,inscale=inscale,
-                    outscale=outscale, indtype=indtype, outdtype=outdtype)
-        self._test = FTData(test_data, test_lbl,inscale=inscale,
-                outscale=outscale, indtype=indtype, outdtype=outdtype)
+            self._train = FTData(train_data, train_lbl, maxsize=maxsize,
+                                 inscale=inscale, outscale=outscale, 
+                                 indtype=indtype, outdtype=outdtype)
+            self._valid = FTData(valid_data, valid_lbl, maxsize=maxsize,
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype)
+        self._test = FTData(test_data, test_lbl, maxsize=maxsize,
+                            inscale=inscale, outscale=outscale,
+                            indtype=indtype, outdtype=outdtype)
 
     def _return_it(self, batchsize, bufsize, ftdata):
         return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),