Mercurial > pylearn
view _test_filetensor.py @ 451:d99fefbc9324
Added a KL-divergence.
author | Joseph Turian <turian@gmail.com> |
---|---|
date | Thu, 04 Sep 2008 14:46:30 -0400 |
parents | 040cb796f4e0 |
children |
line wrap: on
line source
"""Unit tests for the ``filetensor`` module: ndarrays written with ``write``
must come back from ``read`` with the same dtype, shape and values."""
from filetensor import *
import filetensor

import os
import shutil
import sys
import tempfile
import unittest

# Imported explicitly rather than relying on ``from filetensor import *``
# happening to re-export it.
import numpy


class T(unittest.TestCase):
    """Round-trip tests for ``filetensor.write`` / ``filetensor.read``."""

    def setUp(self):
        # A private temporary directory per test: no collision with concurrent
        # runs, and no failure caused by a stale file left by a crashed run.
        self._tmpdir = tempfile.mkdtemp()
        self.fname = os.path.join(self._tmpdir, 'some_mat')

    def tearDown(self):
        # Works even when the test never created self.fname.
        shutil.rmtree(self._tmpdir)

    def _roundtrip(self, gen, close_before_read=False):
        """Write ``gen`` to ``self.fname`` and return what ``read`` loads back.

        When ``close_before_read`` is false the writer is only flushed, so the
        data is read through a second handle while the writer is still open;
        otherwise the writer is closed first. Both handles are always closed
        before this method returns or raises.
        """
        writer = open(self.fname, 'wb')  # binary: no newline translation
        try:
            write(writer, gen)
            if close_before_read:
                writer.close()
            else:
                writer.flush()
            reader = open(self.fname, 'rb')
            try:
                return read(reader, None, debug=False)  # load from file handle
            finally:
                reader.close()
        finally:
            writer.close()  # no-op if it was already closed above

    def _assert_same(self, gen, mat):
        """Fail unless ``mat`` has exactly the shape and values of ``gen``."""
        self.assertEqual(gen.shape, mat.shape)
        self.assertTrue(numpy.all(gen == mat))

    def test_file(self):
        """Data is readable through a new handle once the writer is flushed."""
        gen = numpy.random.rand(1)
        self._assert_same(gen, self._roundtrip(gen))

    def test_filename(self):
        """Data is readable after the writer has been closed."""
        gen = numpy.random.rand(1)
        self._assert_same(gen, self._roundtrip(gen, close_before_read=True))

    def testNd(self):
        """shape and values are stored correctly for tensors of rank 0 to 5"""
        whole_shape = [5, 6, 7, 8, 9]
        # len + 1 so the full rank-5 shape is included too (ranks 0..5).
        for rank in range(len(whole_shape) + 1):
            gen = numpy.asarray(numpy.random.rand(*whole_shape[:rank]))
            self._assert_same(gen, self._roundtrip(gen))

    def test_dtypes(self):
        """dtype, shape and values are stored correctly for all dtypes."""
        for dtype in filetensor._dtype_magic:
            gen = numpy.asarray(numpy.random.rand(4, 5, 2, 1) * 100,
                                dtype=dtype)
            mat = self._roundtrip(gen)
            self.assertEqual(gen.dtype, mat.dtype)
            self._assert_same(gen, mat)

    def test_dtype_invalid(self):
        """write() raises TypeError('Invalid ndarray dtype...') for uint16."""
        gen = numpy.zeros((3, 4), dtype='uint16')  # an unsupported dtype
        passed = False
        f = open(self.fname, 'wb')
        try:
            write(f, gen)
        except TypeError:
            # sys.exc_info() avoids the version-specific `except X, e` /
            # `except X as e` spellings.
            err = sys.exc_info()[1]
            if err.args and str(err.args[0]).startswith('Invalid ndarray dtype'):
                passed = True
        finally:
            f.close()
        self.assertTrue(passed)


def read_ndarray(f, dim, dtype):
    """Read ``prod(dim)`` items of ``dtype`` from open file ``f``, shaped as ``dim``.

    Not used by the tests above. ``dim`` is a sequence of ints; an empty
    ``dim`` reads a single item and returns a 0-d array.
    """
    # numpy.prod of an empty sequence is 1.0, hence the int() for `count`.
    # (The previous `_prod` helper is not exported by `from filetensor import *`.)
    count = int(numpy.prod(dim))
    return numpy.fromfile(f, dtype=dtype, count=count).reshape(dim)


if __name__ == '__main__':
    unittest.main()