comparison _test_filetensor.py @ 375:12ce29abf27d

Automated merge with http://lgcm.iro.umontreal.ca/hg/pylearn
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 16 Jun 2008 17:47:36 -0400
parents b55c829695f1
children 040cb796f4e0
comparison
equal deleted inserted replaced
374:2b16604ffad9 375:12ce29abf27d
1 from filetensor import *
2 import filetensor
3
4 import unittest
5 import os
6
7 class T(unittest.TestCase):
8 fname = '/tmp/some_mat'
9
10 def setUp(self):
11 #TODO: test that /tmp/some_mat does not exist
12 try:
13 os.stat(self.fname)
14 except OSError:
15 return #assume file was not found
16 raise Exception('autotest file "%s" exists!' % self.fname)
17
18 def tearDown(self):
19 os.remove(self.fname)
20
21 def test_file(self):
22 gen = numpy.random.rand(1)
23 f = file(self.fname, 'w');
24 write(f, gen)
25 f.flush()
26 f = file(self.fname, 'r');
27 mat = read(f, None, debug=False) #load from filename
28 self.failUnless(gen.shape == mat.shape)
29 self.failUnless(numpy.all(gen == mat))
30
31 def test_filename(self):
32 gen = numpy.random.rand(1)
33 write(self.fname, gen)
34 mat = read(self.fname, None, debug=False) #load from filename
35 self.failUnless(gen.shape == mat.shape)
36 self.failUnless(numpy.all(gen == mat))
37
38 def testNd(self):
39 """shape and values are stored correctly for tensors of rank 0 to 5"""
40 whole_shape = [5, 6, 7, 8, 9]
41 for i in xrange(5):
42 gen = numpy.asarray(numpy.random.rand(*whole_shape[:i]))
43 f = file(self.fname, 'w');
44 write(f, gen)
45 f.flush()
46 f = file(self.fname, 'r');
47 mat = read(f, None, debug=False) #load from filename
48 self.failUnless(gen.shape == mat.shape)
49 self.failUnless(numpy.all(gen == mat))
50
51 def test_dtypes(self):
52 """shape and values are stored correctly for all dtypes """
53 for dtype in filetensor._dtype_magic:
54 gen = numpy.asarray(
55 numpy.random.rand(4, 5, 2, 1) * 100,
56 dtype=dtype)
57 f = file(self.fname, 'w');
58 write(f, gen)
59 f.flush()
60 f = file(self.fname, 'r');
61 mat = read(f, None, debug=False) #load from filename
62 self.failUnless(gen.dtype == mat.dtype)
63 self.failUnless(gen.shape == mat.shape)
64 self.failUnless(numpy.all(gen == mat))
65
66 def test_dtype_invalid(self):
67 gen = numpy.zeros((3,4), dtype='uint16') #an unsupported dtype
68 f = file(self.fname, 'w')
69 passed = False
70 try:
71 write(f, gen)
72 except TypeError, e:
73 if e[0].startswith('Invalid ndarray dtype'):
74 passed = True
75 f.close()
76 self.failUnless(passed)
77
78
if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner.
    unittest.main()
81
82 #a small test script, starts by reading sys.argv[1]
83 #print 'rval', rval.shape, rval.size
84
if False:
    # Permanently disabled remnant of a manual round-trip script.  It uses
    # `f` and `rval`, which nothing in the visible file assigns.
    write(f, rval)
    print('')
    f.close()
    f = open('/tmp/some_mat', 'r')
    rval2 = read(f)  # load from file handle
    print('%s %s %s' % ('rval2', rval2.shape, rval2.size))

    # The array read back must match the one written in every respect.
    assert rval.dtype == rval2.dtype
    assert rval.shape == rval2.shape
    assert numpy.all(rval == rval2)
    print('ok')
97
def _unused():
    """Compare the file's remaining byte count with the header's claim.

    Debugging fragment.  It reads ``f``, ``f_data_start``, ``debug``,
    ``dim_size`` and ``elsize`` as free variables; none of them is assigned
    in the visible file, so this presumably cannot run as written
    (NOTE(review): confirm before reviving it).
    """
    f.seek(0, 2)  # whence=2: jump to the end of the file
    total_len = f.tell()
    f.seek(f_data_start, 0)  # whence=0: return to where the data begins
    payload_len = total_len - f_data_start

    if debug:
        print('%s %s' % ('length:', total_len))
        print('%s %s' % ('data bytes according to header: ', dim_size * elsize))
        print('%s %s' % ('data bytes according to file : ', payload_len))
        print('reading data...')
    sys.stdout.flush()
113
def read_ndarray(f, dim, dtype):
    """Read ``_prod(dim)`` items of ``dtype`` from ``f`` and reshape them to ``dim``.

    NOTE(review): ``_prod`` is not defined in the visible file; presumably
    it returns the product of the entries of ``dim`` -- confirm where it
    is supposed to come from.
    """
    n_items = _prod(dim)
    flat = numpy.fromfile(f, dtype=dtype, count=n_items)
    return flat.reshape(dim)
116