view _test_filetensor.py @ 453:ce6b4fd3ab29

Fixed typo in help
author delallea@valhalla.apstat.com
date Thu, 04 Sep 2008 13:48:47 -0400
parents 040cb796f4e0
children
line wrap: on
line source

from filetensor import *
import filetensor

import unittest
import os

class T(unittest.TestCase):
    fname = '/tmp/some_mat'

    def setUp(self):
        #TODO: test that /tmp/some_mat does not exist
        try:
            os.stat(self.fname)
        except OSError:
            return #assume file was not found
        raise Exception('autotest file "%s" exists!' % self.fname)

    def tearDown(self):
        os.remove(self.fname)

    def test_file(self):
        gen = numpy.random.rand(1)
        f = file(self.fname, 'w');
        write(f, gen)
        f.flush()
        f = file(self.fname, 'r');
        mat = read(f, None, debug=False) #load from filename
        self.failUnless(gen.shape == mat.shape)
        self.failUnless(numpy.all(gen == mat))

    def test_filename(self):
        gen = numpy.random.rand(1)
        f = file(self.fname, 'w')
        write(f, gen)
        f.close()
        f = file(self.fname, 'r')
        mat = read(f, None, debug=False) #load from filename
        f.close()
        self.failUnless(gen.shape == mat.shape)
        self.failUnless(numpy.all(gen == mat))

    def testNd(self):
        """shape and values are stored correctly for tensors of rank 0 to 5"""
        whole_shape = [5, 6, 7, 8, 9]
        for i in xrange(5):
            gen = numpy.asarray(numpy.random.rand(*whole_shape[:i]))
            f = file(self.fname, 'w');
            write(f, gen)
            f.flush()
            f = file(self.fname, 'r');
            mat = read(f, None, debug=False) #load from filename
            self.failUnless(gen.shape == mat.shape)
            self.failUnless(numpy.all(gen == mat))

    def test_dtypes(self):
        """shape and values are stored correctly for all dtypes """
        for dtype in filetensor._dtype_magic:
            gen = numpy.asarray(
                    numpy.random.rand(4, 5, 2, 1) * 100,
                    dtype=dtype)
            f = file(self.fname, 'w');
            write(f, gen)
            f.flush()
            f = file(self.fname, 'r');
            mat = read(f, None, debug=False) #load from filename
            self.failUnless(gen.dtype == mat.dtype)
            self.failUnless(gen.shape == mat.shape)
            self.failUnless(numpy.all(gen == mat))

    def test_dtype_invalid(self):
        gen = numpy.zeros((3,4), dtype='uint16') #an unsupported dtype
        f = file(self.fname, 'w')
        passed = False
        try:
            write(f, gen)
        except TypeError, e:
            if e[0].startswith('Invalid ndarray dtype'):
                passed = True
        f.close()
        self.failUnless(passed)
        

if __name__ == '__main__':
    # Run the tests above.  unittest.main() calls sys.exit() by default when
    # the run finishes, so any statement placed after it would never execute.
    unittest.main()