comparison filetensor.py @ 248:82ba488b2c24

polished filetensor a little
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 03 Jun 2008 13:14:45 -0400
parents 2b6656b2ef52
children 040cb796f4e0
comparison
legend: equal | deleted | inserted | replaced
left column: 245:c702abb7f875 | right column: 248:82ba488b2c24
14 - for vector: rank=1, dimensions = [?, 1, 1] 14 - for vector: rank=1, dimensions = [?, 1, 1]
15 - for matrix: rank=2, dimensions = [?, ?, 1] 15 - for matrix: rank=2, dimensions = [?, ?, 1]
16 16
17 For rank >= 3, the number of dimensions matches the rank exactly. 17 For rank >= 3, the number of dimensions matches the rank exactly.
18 18
19
20 @todo: add complex type support
21
19 """ 22 """
20 import sys 23 import sys
21 import numpy 24 import numpy
22 25
23 def prod(lst): 26 def _prod(lst):
24 p = 1 27 p = 1
25 for l in lst: 28 for l in lst:
26 p *= l 29 p *= l
27 return p 30 return p
28 31
29 _magic_dtype = { 32 _magic_dtype = {
30 0x1E3D4C51 : ('float32', 4), 33 0x1E3D4C51 : ('float32', 4),
31 0x1E3D4C52 : ('packed matrix', 0), #what is a packed matrix? 34 #0x1E3D4C52 : ('packed matrix', 0), #what is a packed matrix?
32 0x1E3D4C53 : ('float64', 8), 35 0x1E3D4C53 : ('float64', 8),
33 0x1E3D4C54 : ('int32', 4), 36 0x1E3D4C54 : ('int32', 4),
34 0x1E3D4C55 : ('uint8', 1), 37 0x1E3D4C55 : ('uint8', 1),
35 0x1E3D4C56 : ('int16', 2), 38 0x1E3D4C56 : ('int16', 2),
36 } 39 }
37 _dtype_magic = { 40 _dtype_magic = {
38 'float32': 0x1E3D4C51, 41 'float32': 0x1E3D4C51,
39 'packed matrix': 0x1E3D4C52, 42 #'packed matrix': 0x1E3D4C52,
40 'float64': 0x1E3D4C53, 43 'float64': 0x1E3D4C53,
41 'int32': 0x1E3D4C54, 44 'int32': 0x1E3D4C54,
42 'uint8': 0x1E3D4C55, 45 'uint8': 0x1E3D4C55,
43 'int16': 0x1E3D4C56 46 'int16': 0x1E3D4C56
44 } 47 }
45
46 def _unused():
47 f.seek(0,2) #seek to end
48 f_len = f.tell()
49 f.seek(f_data_start,0) #seek back to where we were
50
51 if debug: print 'length:', f_len
52
53
54 f_data_bytes = (f_len - f_data_start)
55
56 if debug: print 'data bytes according to header: ', dim_size * elsize
57 if debug: print 'data bytes according to file : ', f_data_bytes
58
59 if debug: print 'reading data...'
60 sys.stdout.flush()
61
62 def _write_int32(f, i):
63 i_array = numpy.asarray(i, dtype='int32')
64 if 0: print 'writing int32', i, i_array
65 i_array.tofile(f)
66 def _read_int32(f):
67 s = f.read(4)
68 s_array = numpy.fromstring(s, dtype='int32')
69 return s_array.item()
70
71 def read_ndarray(f, dim, dtype):
72 return numpy.fromfile(f, dtype=dtype, count=prod(dim)).reshape(dim)
73 48
74 # 49 #
75 # TODO: implement item selection: 50 # TODO: implement item selection:
76 # e.g. load('some mat', subtensor=(:6, 2:5)) 51 # e.g. load('some mat', subtensor=(:6, 2:5))
77 # 52 #
92 67
93 Support for subtensors is currently spotty, so check the code to see if your 68 Support for subtensors is currently spotty, so check the code to see if your
94 particular type of subtensor is supported. 69 particular type of subtensor is supported.
95 70
96 """ 71 """
72 def _read_int32(f):
73 s = f.read(4)
74 s_array = numpy.fromstring(s, dtype='int32')
75 return s_array.item()
97 76
98 if isinstance(f, str): 77 if isinstance(f, str):
99 if debug: print 'f', f 78 if debug: print 'f', f
100 f = file(f, 'r') 79 f = file(f, 'r')
101 80
113 ndim = _read_int32(f) 92 ndim = _read_int32(f)
114 if debug: print 'header ndim', ndim 93 if debug: print 'header ndim', ndim
115 94
116 #what are the dimensions of the tensor? 95 #what are the dimensions of the tensor?
117 dim = numpy.fromfile(f, dtype='int32', count=max(ndim,3))[:ndim] 96 dim = numpy.fromfile(f, dtype='int32', count=max(ndim,3))[:ndim]
118 dim_size = prod(dim) 97 dim_size = _prod(dim)
119 if debug: print 'header dim', dim, dim_size 98 if debug: print 'header dim', dim, dim_size
120 99
121 rval = None 100 rval = None
122 if subtensor is None: 101 if subtensor is None:
123 rval = read_ndarray(f, dim, magic_t) 102 rval = numpy.fromfile(f, dtype=magic_t, count=_prod(dim)).reshape(dim)
124 elif isinstance(subtensor, slice): 103 elif isinstance(subtensor, slice):
125 if subtensor.step not in (None, 1): 104 if subtensor.step not in (None, 1):
126 raise NotImplementedError('slice with step', subtensor.step) 105 raise NotImplementedError('slice with step', subtensor.step)
127 if subtensor.start not in (None, 0): 106 if subtensor.start not in (None, 0):
128 bytes_per_row = prod(dim[1:]) * elsize 107 bytes_per_row = _prod(dim[1:]) * elsize
129 raise NotImplementedError('slice with start', subtensor.start) 108 raise NotImplementedError('slice with start', subtensor.start)
130 dim[0] = min(dim[0], subtensor.stop) 109 dim[0] = min(dim[0], subtensor.stop)
131 rval = read_ndarray(f, dim, magic_t) 110 rval = numpy.fromfile(f, dtype=magic_t, count=_prod(dim)).reshape(dim)
132 else: 111 else:
133 raise NotImplementedError('subtensor access not written yet:', subtensor) 112 raise NotImplementedError('subtensor access not written yet:', subtensor)
134 113
135 return rval 114 return rval
136 115
137 def write(f, mat): 116 def write(f, mat):
117 """Write a numpy.ndarray to file.
118
119 If 'f' is a string, then it will be interpreted as a filename. This filename
120 will be opened in 'w+' mode, and (automatically) closed at the end of the function.
121 """
122 def _write_int32(f, i):
123 i_array = numpy.asarray(i, dtype='int32')
124 if 0: print 'writing int32', i, i_array
125 i_array.tofile(f)
138 if isinstance(f, str): 126 if isinstance(f, str):
139 f = file(f, 'w') 127 f = file(f, 'w+')
140 128
141 _write_int32(f, _dtype_magic[str(mat.dtype)]) 129 try:
130 _write_int32(f, _dtype_magic[str(mat.dtype)])
131 except KeyError:
132 raise TypeError('Invalid ndarray dtype for filetensor format', mat.dtype)
133
142 _write_int32(f, len(mat.shape)) 134 _write_int32(f, len(mat.shape))
143 shape = mat.shape 135 shape = mat.shape
144 if len(shape) < 3: 136 if len(shape) < 3:
145 shape = list(shape) + [1] * (3 - len(shape)) 137 shape = list(shape) + [1] * (3 - len(shape))
146 print 'writing shape =', shape 138 if 0: print 'writing shape =', shape
147 for sh in shape: 139 for sh in shape:
148 _write_int32(f, sh) 140 _write_int32(f, sh)
149 mat.tofile(f) 141 mat.tofile(f)
150 142
151 if __name__ == '__main__':
152 #a small test script, starts by reading sys.argv[1]
153 rval = read(sys.argv[1], None, debug=True) #load from filename
154 print 'rval', rval.shape, rval.size
155
156 if 0:
157 f = file('/tmp/some_mat', 'w');
158 write(f, rval)
159 print ''
160 f.close()
161 f = file('/tmp/some_mat', 'r');
162 rval2 = read(f) #load from file handle
163 print 'rval2', rval2.shape, rval2.size
164
165 assert rval.dtype == rval2.dtype
166 assert rval.shape == rval2.shape
167 assert numpy.all(rval == rval2)
168 print 'ok'
169