diff datasets/ftfile.py @ 180:76bc047df5ee

Add dtype conversion and rescaling to the read path.
author Arnaud Bergeron <abergeron@gmail.com>
date Sat, 27 Feb 2010 16:50:16 -0500
parents 938bd350dbf0
children f0f47b045cbf
line wrap: on
line diff
--- a/datasets/ftfile.py	Sat Feb 27 16:07:09 2010 -0500
+++ b/datasets/ftfile.py	Sat Feb 27 16:50:16 2010 -0500
@@ -5,7 +5,7 @@
 from itertools import izip, imap
 
 class FTFile(object):
-    def __init__(self, fname):
+    def __init__(self, fname, scale=1, dtype=None):
         r"""
         Tests:
             >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
@@ -13,6 +13,8 @@
         self.file = open(fname, 'rb')
         self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
         self.size = self.dim[0]
+        self.scale = scale
+        self.dtype = dtype
 
     def skip(self, num):
         r"""
@@ -79,10 +81,15 @@
             num = self.size
         self.dim[0] = num
         self.size -= num
-        return numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
+        res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
+        if self.dtype is not None:
+            res = res.astype(self.dtype)
+        if self.scale != 1:
+            res /= self.scale
+        return res
 
 class FTSource(object):
-    def __init__(self, file, skip=0, size=None):
+    def __init__(self, file, skip=0, size=None, dtype=None, scale=1):
         r"""
         Create a data source from a possible subset of a .ft file.
 
@@ -96,7 +103,12 @@
                                       negative truncates to 
                                       filesize - size 
                                       (also after skipping).
-        
+            `dtype` (dtype, optional) -- convert the data to this
+                                         dtype after reading.
+            `scale` (number, optional) -- scale (that is, divide) the
+                                          data by this number (after
+                                          dtype conversion, if any).
+
         Tests:
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
@@ -106,6 +118,8 @@
         self.file = file
         self.skip = skip
         self.size = size
+        self.dtype = dtype
+        self.scale = scale
     
     def open(self):
         r"""
@@ -127,7 +141,7 @@
            >>> s.open().size
            58636
         """
-        f = FTFile(self.file)
+        f = FTFile(self.file, scale=self.scale, dtype=self.dtype)
         if self.skip != 0:
             f.skip(self.skip)
         if self.size is not None and self.size < f.size:
@@ -141,9 +155,12 @@
     r"""
     This is a list of FTSources.
     """
-    def __init__(self, datafiles, labelfiles, skip=0, size=None):
-        self.inputs = [FTSource(f, skip, size) for f in  datafiles]
-        self.outputs = [FTSource(f, skip, size) for f in labelfiles]
+    def __init__(self, datafiles, labelfiles, skip=0, size=None,
+                 inscale=1, indtype=None, outscale=1, outdtype=None):
+        self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype)
+                       for f in  datafiles]
+        self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype)
+                        for f in labelfiles]
 
     def open_inputs(self):
         return [f.open() for f in self.inputs]
@@ -153,7 +170,7 @@
     
 
 class FTDataSet(DataSet):
-    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None):
+    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1):
         r"""
         Defines a DataSet from a bunch of files.
         
@@ -165,6 +182,9 @@
                                          can differ from train.
            `valid_data`, `valid_labels` -- same thing again for validation.
                                            (optional)
+           `indtype`, `outdtype`, -- optional dtype conversion and scaling
+           `inscale`, `outscale`     parameters; see FTSource.__init__()
+
 
         If `valid_data` and `valid_labels` are not supplied then a sample
         approximately equal in size to the test set is taken from the train