changeset 248:82ba488b2c24

polished filetensor a little
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 03 Jun 2008 13:14:45 -0400
parents c702abb7f875
children e93e511deb9a
files filetensor.py test_filetensor.py
diffstat 2 files changed, 146 insertions(+), 57 deletions(-) [+]
line wrap: on
line diff
--- a/filetensor.py	Mon Jun 02 17:09:58 2008 -0400
+++ b/filetensor.py	Tue Jun 03 13:14:45 2008 -0400
@@ -16,11 +16,14 @@
 
 For rank >= 3, the number of dimensions matches the rank exactly.
 
+
+@todo: add complex type support
+
 """
 import sys
 import numpy
 
-def prod(lst):
+def _prod(lst):
     p = 1
     for l in lst:
         p *= l
@@ -28,7 +31,7 @@
 
 _magic_dtype = {
         0x1E3D4C51 : ('float32', 4),
-        0x1E3D4C52 : ('packed matrix', 0), #what is a packed matrix?
+        #0x1E3D4C52 : ('packed matrix', 0), #what is a packed matrix?
         0x1E3D4C53 : ('float64', 8),
         0x1E3D4C54 : ('int32', 4),
         0x1E3D4C55 : ('uint8', 1),
@@ -36,41 +39,13 @@
         }
 _dtype_magic = {
         'float32': 0x1E3D4C51,
-        'packed matrix': 0x1E3D4C52,
+        #'packed matrix': 0x1E3D4C52,
         'float64': 0x1E3D4C53,
         'int32': 0x1E3D4C54,
         'uint8': 0x1E3D4C55,
         'int16': 0x1E3D4C56
         }
 
-def _unused():
-    f.seek(0,2) #seek to end
-    f_len =  f.tell()
-    f.seek(f_data_start,0) #seek back to where we were
-
-    if debug: print 'length:', f_len
-
-
-    f_data_bytes = (f_len - f_data_start)
-
-    if debug: print 'data bytes according to header: ', dim_size * elsize
-    if debug: print 'data bytes according to file  : ', f_data_bytes
-
-    if debug: print 'reading data...'
-    sys.stdout.flush()
-
-def _write_int32(f, i):
-    i_array = numpy.asarray(i, dtype='int32')
-    if 0: print 'writing int32', i, i_array
-    i_array.tofile(f)
-def _read_int32(f):
-    s = f.read(4)
-    s_array = numpy.fromstring(s, dtype='int32')
-    return s_array.item()
-
-def read_ndarray(f, dim, dtype):
-    return numpy.fromfile(f, dtype=dtype, count=prod(dim)).reshape(dim)
-
 #
 # TODO: implement item selection:
 #  e.g. load('some mat', subtensor=(:6, 2:5))
@@ -94,6 +69,10 @@
     particular type of subtensor is supported.
 
     """
+    def _read_int32(f):
+        s = f.read(4)
+        s_array = numpy.fromstring(s, dtype='int32')
+        return s_array.item()
 
     if isinstance(f, str):
         if debug: print 'f', f
@@ -115,55 +94,49 @@
 
     #what are the dimensions of the tensor?
     dim = numpy.fromfile(f, dtype='int32', count=max(ndim,3))[:ndim]
-    dim_size = prod(dim)
+    dim_size = _prod(dim)
     if debug: print 'header dim', dim, dim_size
 
     rval = None
     if subtensor is None:
-        rval = read_ndarray(f, dim, magic_t)
+        rval = numpy.fromfile(f, dtype=magic_t, count=_prod(dim)).reshape(dim)
     elif isinstance(subtensor, slice):
         if subtensor.step not in (None, 1):
             raise NotImplementedError('slice with step', subtensor.step)
         if subtensor.start not in (None, 0):
-            bytes_per_row = prod(dim[1:]) * elsize
+            bytes_per_row = _prod(dim[1:]) * elsize
             raise NotImplementedError('slice with start', subtensor.start)
         dim[0] = min(dim[0], subtensor.stop)
-        rval = read_ndarray(f, dim, magic_t)
+        rval = numpy.fromfile(f, dtype=magic_t, count=_prod(dim)).reshape(dim)
     else:
         raise NotImplementedError('subtensor access not written yet:', subtensor) 
 
     return rval
 
 def write(f, mat):
+    """Write a numpy.ndarray to file.
+
+    If 'f' is a string, then it will be interpreted as a filename. This filename
+    will be opened in 'w+' mode, and (automatically) closed at the end of the function.
+    """
+    def _write_int32(f, i):
+        i_array = numpy.asarray(i, dtype='int32')
+        if 0: print 'writing int32', i, i_array
+        i_array.tofile(f)
     if isinstance(f, str):
-        f = file(f, 'w')
+        f = file(f, 'w+')
 
-    _write_int32(f, _dtype_magic[str(mat.dtype)])
+    try:
+        _write_int32(f, _dtype_magic[str(mat.dtype)])
+    except KeyError:
+        raise TypeError('Invalid ndarray dtype for filetensor format', mat.dtype)
+
     _write_int32(f, len(mat.shape))
     shape = mat.shape
     if len(shape) < 3:
         shape = list(shape) + [1] * (3 - len(shape))
-    print 'writing shape =', shape
+    if 0: print 'writing shape =', shape
     for sh in shape:
         _write_int32(f, sh)
     mat.tofile(f)
 
-if __name__ == '__main__':
-    #a small test script, starts by reading sys.argv[1]
-    rval = read(sys.argv[1], None, debug=True) #load from filename
-    print 'rval', rval.shape, rval.size
-
-    if 0:
-        f = file('/tmp/some_mat', 'w');
-        write(f, rval)
-        print ''
-        f.close()
-        f = file('/tmp/some_mat', 'r');
-        rval2 = read(f) #load from file handle
-        print 'rval2', rval2.shape, rval2.size
-
-        assert rval.dtype == rval2.dtype
-        assert rval.shape == rval2.shape
-        assert numpy.all(rval == rval2)
-        print 'ok'
-
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test_filetensor.py	Tue Jun 03 13:14:45 2008 -0400
@@ -0,0 +1,116 @@
+from filetensor import *
+import filetensor
+
+import unittest
+import os
+
+class T(unittest.TestCase):
+    fname = '/tmp/some_mat'
+
+    def setUp(self):
+        #TODO: test that /tmp/some_mat does not exist
+        try:
+            os.stat(self.fname)
+        except OSError:
+            return #assume file was not found
+        raise Exception('autotest file "%s" exists!' % self.fname)
+
+    def tearDown(self):
+        os.remove(self.fname)
+
+    def test_file(self):
+        gen = numpy.random.rand(1)
+        f = file(self.fname, 'w');
+        write(f, gen)
+        f.flush()
+        f = file(self.fname, 'r');
+        mat = read(f, None, debug=False) #load from filename
+        self.failUnless(gen.shape == mat.shape)
+        self.failUnless(numpy.all(gen == mat))
+
+    def test_filename(self):
+        gen = numpy.random.rand(1)
+        write(self.fname, gen)
+        mat = read(self.fname, None, debug=False) #load from filename
+        self.failUnless(gen.shape == mat.shape)
+        self.failUnless(numpy.all(gen == mat))
+
+    def testNd(self):
+        """shape and values are stored correctly for tensors of rank 0 to 5"""
+        whole_shape = [5, 6, 7, 8, 9]
+        for i in xrange(5):
+            gen = numpy.asarray(numpy.random.rand(*whole_shape[:i]))
+            f = file(self.fname, 'w');
+            write(f, gen)
+            f.flush()
+            f = file(self.fname, 'r');
+            mat = read(f, None, debug=False) #load from filename
+            self.failUnless(gen.shape == mat.shape)
+            self.failUnless(numpy.all(gen == mat))
+
+    def test_dtypes(self):
+        """shape and values are stored correctly for all dtypes """
+        for dtype in filetensor._dtype_magic:
+            gen = numpy.asarray(
+                    numpy.random.rand(4, 5, 2, 1) * 100,
+                    dtype=dtype)
+            f = file(self.fname, 'w');
+            write(f, gen)
+            f.flush()
+            f = file(self.fname, 'r');
+            mat = read(f, None, debug=False) #load from filename
+            self.failUnless(gen.dtype == mat.dtype)
+            self.failUnless(gen.shape == mat.shape)
+            self.failUnless(numpy.all(gen == mat))
+
+    def test_dtype_invalid(self):
+        gen = numpy.zeros((3,4), dtype='uint16') #an unsupported dtype
+        f = file(self.fname, 'w')
+        passed = False
+        try:
+            write(f, gen)
+        except TypeError, e:
+            if e[0].startswith('Invalid ndarray dtype'):
+                passed = True
+        f.close()
+        self.failUnless(passed)
+        
+
+if __name__ == '__main__':
+    unittest.main()
+
+    #a small test script, starts by reading sys.argv[1]
+    #print 'rval', rval.shape, rval.size
+
+    if 0:
+        write(f, rval)
+        print ''
+        f.close()
+        f = file('/tmp/some_mat', 'r');
+        rval2 = read(f) #load from file handle
+        print 'rval2', rval2.shape, rval2.size
+
+        assert rval.dtype == rval2.dtype
+        assert rval.shape == rval2.shape
+        assert numpy.all(rval == rval2)
+        print 'ok'
+
+    def _unused():
+        f.seek(0,2) #seek to end
+        f_len =  f.tell()
+        f.seek(f_data_start,0) #seek back to where we were
+
+        if debug: print 'length:', f_len
+
+
+        f_data_bytes = (f_len - f_data_start)
+
+        if debug: print 'data bytes according to header: ', dim_size * elsize
+        if debug: print 'data bytes according to file  : ', f_data_bytes
+
+        if debug: print 'reading data...'
+        sys.stdout.flush()
+
+    def read_ndarray(f, dim, dtype):
+        return numpy.fromfile(f, dtype=dtype, count=_prod(dim)).reshape(dim)
+