Mercurial > pylearn
comparison filetensor.py @ 248:82ba488b2c24
polished filetensor a little
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 13:14:45 -0400 |
parents | 2b6656b2ef52 |
children | 040cb796f4e0 |
comparison
equal
deleted
inserted
replaced
245:c702abb7f875 | 248:82ba488b2c24 |
---|---|
14 - for vector: rank=1, dimensions = [?, 1, 1] | 14 - for vector: rank=1, dimensions = [?, 1, 1] |
15 - for matrix: rank=2, dimensions = [?, ?, 1] | 15 - for matrix: rank=2, dimensions = [?, ?, 1] |
16 | 16 |
17 For rank >= 3, the number of dimensions matches the rank exactly. | 17 For rank >= 3, the number of dimensions matches the rank exactly. |
18 | 18 |
19 | |
20 @todo: add complex type support | |
21 | |
19 """ | 22 """ |
20 import sys | 23 import sys |
21 import numpy | 24 import numpy |
22 | 25 |
23 def prod(lst): | 26 def _prod(lst): |
24 p = 1 | 27 p = 1 |
25 for l in lst: | 28 for l in lst: |
26 p *= l | 29 p *= l |
27 return p | 30 return p |
28 | 31 |
29 _magic_dtype = { | 32 _magic_dtype = { |
30 0x1E3D4C51 : ('float32', 4), | 33 0x1E3D4C51 : ('float32', 4), |
31 0x1E3D4C52 : ('packed matrix', 0), #what is a packed matrix? | 34 #0x1E3D4C52 : ('packed matrix', 0), #what is a packed matrix? |
32 0x1E3D4C53 : ('float64', 8), | 35 0x1E3D4C53 : ('float64', 8), |
33 0x1E3D4C54 : ('int32', 4), | 36 0x1E3D4C54 : ('int32', 4), |
34 0x1E3D4C55 : ('uint8', 1), | 37 0x1E3D4C55 : ('uint8', 1), |
35 0x1E3D4C56 : ('int16', 2), | 38 0x1E3D4C56 : ('int16', 2), |
36 } | 39 } |
37 _dtype_magic = { | 40 _dtype_magic = { |
38 'float32': 0x1E3D4C51, | 41 'float32': 0x1E3D4C51, |
39 'packed matrix': 0x1E3D4C52, | 42 #'packed matrix': 0x1E3D4C52, |
40 'float64': 0x1E3D4C53, | 43 'float64': 0x1E3D4C53, |
41 'int32': 0x1E3D4C54, | 44 'int32': 0x1E3D4C54, |
42 'uint8': 0x1E3D4C55, | 45 'uint8': 0x1E3D4C55, |
43 'int16': 0x1E3D4C56 | 46 'int16': 0x1E3D4C56 |
44 } | 47 } |
45 | |
46 def _unused(): | |
47 f.seek(0,2) #seek to end | |
48 f_len = f.tell() | |
49 f.seek(f_data_start,0) #seek back to where we were | |
50 | |
51 if debug: print 'length:', f_len | |
52 | |
53 | |
54 f_data_bytes = (f_len - f_data_start) | |
55 | |
56 if debug: print 'data bytes according to header: ', dim_size * elsize | |
57 if debug: print 'data bytes according to file : ', f_data_bytes | |
58 | |
59 if debug: print 'reading data...' | |
60 sys.stdout.flush() | |
61 | |
62 def _write_int32(f, i): | |
63 i_array = numpy.asarray(i, dtype='int32') | |
64 if 0: print 'writing int32', i, i_array | |
65 i_array.tofile(f) | |
66 def _read_int32(f): | |
67 s = f.read(4) | |
68 s_array = numpy.fromstring(s, dtype='int32') | |
69 return s_array.item() | |
70 | |
71 def read_ndarray(f, dim, dtype): | |
72 return numpy.fromfile(f, dtype=dtype, count=prod(dim)).reshape(dim) | |
73 | 48 |
74 # | 49 # |
75 # TODO: implement item selection: | 50 # TODO: implement item selection: |
76 # e.g. load('some mat', subtensor=(:6, 2:5)) | 51 # e.g. load('some mat', subtensor=(:6, 2:5)) |
77 # | 52 # |
92 | 67 |
93 Support for subtensors is currently spotty, so check the code to see if your | 68 Support for subtensors is currently spotty, so check the code to see if your |
94 particular type of subtensor is supported. | 69 particular type of subtensor is supported. |
95 | 70 |
96 """ | 71 """ |
72 def _read_int32(f): | |
73 s = f.read(4) | |
74 s_array = numpy.fromstring(s, dtype='int32') | |
75 return s_array.item() | |
97 | 76 |
98 if isinstance(f, str): | 77 if isinstance(f, str): |
99 if debug: print 'f', f | 78 if debug: print 'f', f |
100 f = file(f, 'r') | 79 f = file(f, 'r') |
101 | 80 |
113 ndim = _read_int32(f) | 92 ndim = _read_int32(f) |
114 if debug: print 'header ndim', ndim | 93 if debug: print 'header ndim', ndim |
115 | 94 |
116 #what are the dimensions of the tensor? | 95 #what are the dimensions of the tensor? |
117 dim = numpy.fromfile(f, dtype='int32', count=max(ndim,3))[:ndim] | 96 dim = numpy.fromfile(f, dtype='int32', count=max(ndim,3))[:ndim] |
118 dim_size = prod(dim) | 97 dim_size = _prod(dim) |
119 if debug: print 'header dim', dim, dim_size | 98 if debug: print 'header dim', dim, dim_size |
120 | 99 |
121 rval = None | 100 rval = None |
122 if subtensor is None: | 101 if subtensor is None: |
123 rval = read_ndarray(f, dim, magic_t) | 102 rval = numpy.fromfile(f, dtype=magic_t, count=_prod(dim)).reshape(dim) |
124 elif isinstance(subtensor, slice): | 103 elif isinstance(subtensor, slice): |
125 if subtensor.step not in (None, 1): | 104 if subtensor.step not in (None, 1): |
126 raise NotImplementedError('slice with step', subtensor.step) | 105 raise NotImplementedError('slice with step', subtensor.step) |
127 if subtensor.start not in (None, 0): | 106 if subtensor.start not in (None, 0): |
128 bytes_per_row = prod(dim[1:]) * elsize | 107 bytes_per_row = _prod(dim[1:]) * elsize |
129 raise NotImplementedError('slice with start', subtensor.start) | 108 raise NotImplementedError('slice with start', subtensor.start) |
130 dim[0] = min(dim[0], subtensor.stop) | 109 dim[0] = min(dim[0], subtensor.stop) |
131 rval = read_ndarray(f, dim, magic_t) | 110 rval = numpy.fromfile(f, dtype=magic_t, count=_prod(dim)).reshape(dim) |
132 else: | 111 else: |
133 raise NotImplementedError('subtensor access not written yet:', subtensor) | 112 raise NotImplementedError('subtensor access not written yet:', subtensor) |
134 | 113 |
135 return rval | 114 return rval |
136 | 115 |
137 def write(f, mat): | 116 def write(f, mat): |
117 """Write a numpy.ndarray to file. | |
118 | |
119 If 'f' is a string, then it will be interpreted as a filename. This filename | |
120 will be opened in 'w+' mode, and (automatically) closed at the end of the function. | |
121 """ | |
122 def _write_int32(f, i): | |
123 i_array = numpy.asarray(i, dtype='int32') | |
124 if 0: print 'writing int32', i, i_array | |
125 i_array.tofile(f) | |
138 if isinstance(f, str): | 126 if isinstance(f, str): |
139 f = file(f, 'w') | 127 f = file(f, 'w+') |
140 | 128 |
141 _write_int32(f, _dtype_magic[str(mat.dtype)]) | 129 try: |
130 _write_int32(f, _dtype_magic[str(mat.dtype)]) | |
131 except KeyError: | |
132 raise TypeError('Invalid ndarray dtype for filetensor format', mat.dtype) | |
133 | |
142 _write_int32(f, len(mat.shape)) | 134 _write_int32(f, len(mat.shape)) |
143 shape = mat.shape | 135 shape = mat.shape |
144 if len(shape) < 3: | 136 if len(shape) < 3: |
145 shape = list(shape) + [1] * (3 - len(shape)) | 137 shape = list(shape) + [1] * (3 - len(shape)) |
146 print 'writing shape =', shape | 138 if 0: print 'writing shape =', shape |
147 for sh in shape: | 139 for sh in shape: |
148 _write_int32(f, sh) | 140 _write_int32(f, sh) |
149 mat.tofile(f) | 141 mat.tofile(f) |
150 | 142 |
151 if __name__ == '__main__': | |
152 #a small test script, starts by reading sys.argv[1] | |
153 rval = read(sys.argv[1], None, debug=True) #load from filename | |
154 print 'rval', rval.shape, rval.size | |
155 | |
156 if 0: | |
157 f = file('/tmp/some_mat', 'w'); | |
158 write(f, rval) | |
159 print '' | |
160 f.close() | |
161 f = file('/tmp/some_mat', 'r'); | |
162 rval2 = read(f) #load from file handle | |
163 print 'rval2', rval2.shape, rval2.size | |
164 | |
165 assert rval.dtype == rval2.dtype | |
166 assert rval.shape == rval2.shape | |
167 assert numpy.all(rval == rval2) | |
168 print 'ok' | |
169 |