Mercurial > pylearn
diff _test_filetensor.py @ 375:12ce29abf27d
Automated merge with http://lgcm.iro.umontreal.ca/hg/pylearn
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Mon, 16 Jun 2008 17:47:36 -0400 |
parents | b55c829695f1 |
children | 040cb796f4e0 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/_test_filetensor.py Mon Jun 16 17:47:36 2008 -0400 @@ -0,0 +1,116 @@ +from filetensor import * +import filetensor + +import unittest +import os + +class T(unittest.TestCase): + fname = '/tmp/some_mat' + + def setUp(self): + #TODO: test that /tmp/some_mat does not exist + try: + os.stat(self.fname) + except OSError: + return #assume file was not found + raise Exception('autotest file "%s" exists!' % self.fname) + + def tearDown(self): + os.remove(self.fname) + + def test_file(self): + gen = numpy.random.rand(1) + f = file(self.fname, 'w'); + write(f, gen) + f.flush() + f = file(self.fname, 'r'); + mat = read(f, None, debug=False) #load from filename + self.failUnless(gen.shape == mat.shape) + self.failUnless(numpy.all(gen == mat)) + + def test_filename(self): + gen = numpy.random.rand(1) + write(self.fname, gen) + mat = read(self.fname, None, debug=False) #load from filename + self.failUnless(gen.shape == mat.shape) + self.failUnless(numpy.all(gen == mat)) + + def testNd(self): + """shape and values are stored correctly for tensors of rank 0 to 5""" + whole_shape = [5, 6, 7, 8, 9] + for i in xrange(5): + gen = numpy.asarray(numpy.random.rand(*whole_shape[:i])) + f = file(self.fname, 'w'); + write(f, gen) + f.flush() + f = file(self.fname, 'r'); + mat = read(f, None, debug=False) #load from filename + self.failUnless(gen.shape == mat.shape) + self.failUnless(numpy.all(gen == mat)) + + def test_dtypes(self): + """shape and values are stored correctly for all dtypes """ + for dtype in filetensor._dtype_magic: + gen = numpy.asarray( + numpy.random.rand(4, 5, 2, 1) * 100, + dtype=dtype) + f = file(self.fname, 'w'); + write(f, gen) + f.flush() + f = file(self.fname, 'r'); + mat = read(f, None, debug=False) #load from filename + self.failUnless(gen.dtype == mat.dtype) + self.failUnless(gen.shape == mat.shape) + self.failUnless(numpy.all(gen == mat)) + + def test_dtype_invalid(self): + gen = numpy.zeros((3,4), dtype='uint16') #an unsupported dtype + f = file(self.fname, 'w') + passed = False + try: + write(f, gen) + except TypeError, e: + if e[0].startswith('Invalid ndarray dtype'): + passed = True + f.close() + self.failUnless(passed) + + +if __name__ == '__main__': + unittest.main() + + #a small test script, starts by reading sys.argv[1] + #print 'rval', rval.shape, rval.size + + if 0: + write(f, rval) + print '' + f.close() + f = file('/tmp/some_mat', 'r'); + rval2 = read(f) #load from file handle + print 'rval2', rval2.shape, rval2.size + + assert rval.dtype == rval2.dtype + assert rval.shape == rval2.shape + assert numpy.all(rval == rval2) + print 'ok' + + def _unused(): + f.seek(0,2) #seek to end + f_len = f.tell() + f.seek(f_data_start,0) #seek back to where we were + + if debug: print 'length:', f_len + + + f_data_bytes = (f_len - f_data_start) + + if debug: print 'data bytes according to header: ', dim_size * elsize + if debug: print 'data bytes according to file : ', f_data_bytes + + if debug: print 'reading data...' + sys.stdout.flush() + + def read_ndarray(f, dim, dtype): + return numpy.fromfile(f, dtype=dtype, count=_prod(dim)).reshape(dim) +