Mercurial > pylearn
comparison test_filetensor.py @ 248:82ba488b2c24
polished filetensor a little
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 13:14:45 -0400 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
245:c702abb7f875 | 248:82ba488b2c24 |
---|---|
1 from filetensor import * | |
2 import filetensor | |
3 | |
4 import unittest | |
5 import os | |
6 | |
class T(unittest.TestCase):
    """Round-trip tests for filetensor ``write`` / ``read``.

    Every test writes to ``fname`` and ``tearDown`` deletes it, so ``setUp``
    refuses to run when that path already exists rather than clobbering (and
    later deleting) a file this test case did not create.
    """

    fname = '/tmp/some_mat'

    def setUp(self):
        # Guard against overwriting/removing somebody else's file.
        if os.path.exists(self.fname):
            raise Exception('autotest file "%s" exists!' % self.fname)

    def tearDown(self):
        # A test that failed before creating the file must not additionally
        # report an OSError from here, burying the real failure.
        if os.path.exists(self.fname):
            os.remove(self.fname)

    def _roundtrip(self, gen):
        """Write ``gen`` through a file handle, read it back through a fresh
        handle, and return the loaded array.

        Both handles are closed before returning. Binary mode is used because
        the payload is raw bytes; text mode may translate newlines.
        """
        with open(self.fname, 'wb') as f:
            write(f, gen)
        with open(self.fname, 'rb') as f:
            return read(f, None, debug=False)  # load from file handle

    def _check_same(self, gen, mat, msg=None):
        """Assert that ``mat`` has the same shape and values as ``gen``."""
        self.assertEqual(gen.shape, mat.shape, msg)
        self.assertTrue(numpy.all(gen == mat), msg)

    def test_file(self):
        """A 1-element array survives a round trip through file handles."""
        gen = numpy.random.rand(1)
        mat = self._roundtrip(gen)
        self._check_same(gen, mat)

    def test_filename(self):
        """A 1-element array survives a round trip through a filename."""
        gen = numpy.random.rand(1)
        write(self.fname, gen)
        mat = read(self.fname, None, debug=False)  # load from filename
        self._check_same(gen, mat)

    def testNd(self):
        """shape and values are stored correctly for tensors of rank 0 to 5"""
        whole_shape = [5, 6, 7, 8, 9]
        # len + 1 so the full rank-5 shape is exercised, as documented above.
        for rank in range(len(whole_shape) + 1):
            gen = numpy.asarray(numpy.random.rand(*whole_shape[:rank]))
            mat = self._roundtrip(gen)
            self._check_same(gen, mat, 'rank %d' % rank)

    def test_dtypes(self):
        """shape and values are stored correctly for all dtypes """
        for dtype in filetensor._dtype_magic:
            gen = numpy.asarray(
                    numpy.random.rand(4, 5, 2, 1) * 100,
                    dtype=dtype)
            mat = self._roundtrip(gen)
            self.assertEqual(gen.dtype, mat.dtype, 'dtype %s' % (dtype,))
            self._check_same(gen, mat, 'dtype %s' % (dtype,))

    def test_dtype_invalid(self):
        """Writing an unsupported dtype raises TypeError('Invalid ndarray dtype...')."""
        gen = numpy.zeros((3, 4), dtype='uint16')  # an unsupported dtype
        passed = False
        with open(self.fname, 'wb') as f:
            try:
                write(f, gen)
            except TypeError as e:
                # e.args instead of e[0]: exception indexing is Python-2 only.
                passed = bool(e.args) and \
                        str(e.args[0]).startswith('Invalid ndarray dtype')
        self.assertTrue(passed)
77 | |
78 | |
if __name__ == '__main__':
    # Run the TestCase(s) defined in this module when executed as a script.
    unittest.main()
81 | |
# A manual write/read round-trip script used to live here, permanently
# disabled behind ``if 0:``. It referenced undefined names (``rval``, ``f``)
# and used Python-2-only ``print`` statements, which keep the module from
# parsing under Python 3. The same round trip is covered by ``T.test_file``.
97 | |
def _unused():
    """Dead leftover debugging code; nothing in this file calls it.

    NOTE(review): the body relies on f, f_data_start, debug, dim_size and
    elsize, none of which are parameters or locals here. None of them is
    defined at module level in the lines visible in this file, and neither is
    sys. Calling this presumably raises NameError. It looks copied out of a
    file-reading routine. Candidate for deletion.
    """
    f.seek(0,2) #seek to end
    f_len = f.tell()
    f.seek(f_data_start,0) #seek back to where we were

    if debug: print 'length:', f_len


    f_data_bytes = (f_len - f_data_start)

    # Compare the byte count implied by the header with what the file holds.
    if debug: print 'data bytes according to header: ', dim_size * elsize
    if debug: print 'data bytes according to file : ', f_data_bytes

    if debug: print 'reading data...'
    sys.stdout.flush()
113 | |
def read_ndarray(f, dim, dtype):
    """Read ``prod(dim)`` elements of ``dtype`` from ``f`` and reshape to ``dim``.

    :param f: open binary file (or filename) accepted by ``numpy.fromfile``;
        reading starts at the current file position.
    :param dim: sequence of non-negative ints giving the target shape; an
        empty sequence yields a 0-d array holding one element.
    :param dtype: anything ``numpy.dtype`` accepts.
    :returns: ndarray with shape ``tuple(dim)``.
    :raises ValueError: if ``f`` holds fewer elements than ``dim`` requires.
    """
    # Count computed locally with Python ints, so it cannot overflow. This is
    # also because ``_prod`` is a private filetensor helper:
    # ``from filetensor import *`` skips underscore names unless they are
    # listed in ``__all__``.
    count = 1
    for d in dim:
        count *= int(d)
    arr = numpy.fromfile(f, dtype=dtype, count=count)
    if arr.size != count:
        # A truncated file would otherwise surface as an opaque reshape error.
        raise ValueError('read_ndarray: expected %d elements of %s, got %d'
                         % (count, numpy.dtype(dtype), arr.size))
    return arr.reshape(dim)
116 |