diff test_filetensor.py @ 248:82ba488b2c24

polished filetensor a little
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 03 Jun 2008 13:14:45 -0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test_filetensor.py	Tue Jun 03 13:14:45 2008 -0400
@@ -0,0 +1,116 @@
+from filetensor import *
+import filetensor
+
+import unittest
+import os
+
+class T(unittest.TestCase):
+    fname = '/tmp/some_mat'
+
+    def setUp(self):
+        #TODO: test that /tmp/some_mat does not exist
+        try:
+            os.stat(self.fname)
+        except OSError:
+            return #assume file was not found
+        raise Exception('autotest file "%s" exists!' % self.fname)
+
+    def tearDown(self):
+        os.remove(self.fname)
+
+    def test_file(self):
+        gen = numpy.random.rand(1)
+        f = file(self.fname, 'w');
+        write(f, gen)
+        f.flush()
+        f = file(self.fname, 'r');
+        mat = read(f, None, debug=False) #load from filename
+        self.failUnless(gen.shape == mat.shape)
+        self.failUnless(numpy.all(gen == mat))
+
+    def test_filename(self):
+        gen = numpy.random.rand(1)
+        write(self.fname, gen)
+        mat = read(self.fname, None, debug=False) #load from filename
+        self.failUnless(gen.shape == mat.shape)
+        self.failUnless(numpy.all(gen == mat))
+
+    def testNd(self):
+        """shape and values are stored correctly for tensors of rank 0 to 5"""
+        whole_shape = [5, 6, 7, 8, 9]
+        for i in xrange(5):
+            gen = numpy.asarray(numpy.random.rand(*whole_shape[:i]))
+            f = file(self.fname, 'w');
+            write(f, gen)
+            f.flush()
+            f = file(self.fname, 'r');
+            mat = read(f, None, debug=False) #load from filename
+            self.failUnless(gen.shape == mat.shape)
+            self.failUnless(numpy.all(gen == mat))
+
+    def test_dtypes(self):
+        """shape and values are stored correctly for all dtypes """
+        for dtype in filetensor._dtype_magic:
+            gen = numpy.asarray(
+                    numpy.random.rand(4, 5, 2, 1) * 100,
+                    dtype=dtype)
+            f = file(self.fname, 'w');
+            write(f, gen)
+            f.flush()
+            f = file(self.fname, 'r');
+            mat = read(f, None, debug=False) #load from filename
+            self.failUnless(gen.dtype == mat.dtype)
+            self.failUnless(gen.shape == mat.shape)
+            self.failUnless(numpy.all(gen == mat))
+
+    def test_dtype_invalid(self):
+        gen = numpy.zeros((3,4), dtype='uint16') #an unsupported dtype
+        f = file(self.fname, 'w')
+        passed = False
+        try:
+            write(f, gen)
+        except TypeError, e:
+            if e[0].startswith('Invalid ndarray dtype'):
+                passed = True
+        f.close()
+        self.failUnless(passed)
+        
+
+if __name__ == '__main__':
+    unittest.main()
+
+    #a small test script, starts by reading sys.argv[1]
+    #print 'rval', rval.shape, rval.size
+
+    if 0:
+        write(f, rval)
+        print ''
+        f.close()
+        f = file('/tmp/some_mat', 'r');
+        rval2 = read(f) #load from file handle
+        print 'rval2', rval2.shape, rval2.size
+
+        assert rval.dtype == rval2.dtype
+        assert rval.shape == rval2.shape
+        assert numpy.all(rval == rval2)
+        print 'ok'
+
+    def _unused():
+        f.seek(0,2) #seek to end
+        f_len =  f.tell()
+        f.seek(f_data_start,0) #seek back to where we were
+
+        if debug: print 'length:', f_len
+
+
+        f_data_bytes = (f_len - f_data_start)
+
+        if debug: print 'data bytes according to header: ', dim_size * elsize
+        if debug: print 'data bytes according to file  : ', f_data_bytes
+
+        if debug: print 'reading data...'
+        sys.stdout.flush()
+
+    def read_ndarray(f, dim, dtype):
+        return numpy.fromfile(f, dtype=dtype, count=_prod(dim)).reshape(dim)
+