Mercurial > pylearn
diff filetensor.py @ 248:82ba488b2c24
polished filetensor a little
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 13:14:45 -0400 |
parents | 2b6656b2ef52 |
children | 040cb796f4e0 |
line wrap: on
line diff
--- a/filetensor.py Mon Jun 02 17:09:58 2008 -0400 +++ b/filetensor.py Tue Jun 03 13:14:45 2008 -0400 @@ -16,11 +16,14 @@ For rank >= 3, the number of dimensions matches the rank exactly. + +@todo: add complex type support + """ import sys import numpy -def prod(lst): +def _prod(lst): p = 1 for l in lst: p *= l @@ -28,7 +31,7 @@ _magic_dtype = { 0x1E3D4C51 : ('float32', 4), - 0x1E3D4C52 : ('packed matrix', 0), #what is a packed matrix? + #0x1E3D4C52 : ('packed matrix', 0), #what is a packed matrix? 0x1E3D4C53 : ('float64', 8), 0x1E3D4C54 : ('int32', 4), 0x1E3D4C55 : ('uint8', 1), @@ -36,41 +39,13 @@ } _dtype_magic = { 'float32': 0x1E3D4C51, - 'packed matrix': 0x1E3D4C52, + #'packed matrix': 0x1E3D4C52, 'float64': 0x1E3D4C53, 'int32': 0x1E3D4C54, 'uint8': 0x1E3D4C55, 'int16': 0x1E3D4C56 } -def _unused(): - f.seek(0,2) #seek to end - f_len = f.tell() - f.seek(f_data_start,0) #seek back to where we were - - if debug: print 'length:', f_len - - - f_data_bytes = (f_len - f_data_start) - - if debug: print 'data bytes according to header: ', dim_size * elsize - if debug: print 'data bytes according to file : ', f_data_bytes - - if debug: print 'reading data...' - sys.stdout.flush() - -def _write_int32(f, i): - i_array = numpy.asarray(i, dtype='int32') - if 0: print 'writing int32', i, i_array - i_array.tofile(f) -def _read_int32(f): - s = f.read(4) - s_array = numpy.fromstring(s, dtype='int32') - return s_array.item() - -def read_ndarray(f, dim, dtype): - return numpy.fromfile(f, dtype=dtype, count=prod(dim)).reshape(dim) - # # TODO: implement item selection: # e.g. load('some mat', subtensor=(:6, 2:5)) @@ -94,6 +69,10 @@ particular type of subtensor is supported. """ + def _read_int32(f): + s = f.read(4) + s_array = numpy.fromstring(s, dtype='int32') + return s_array.item() if isinstance(f, str): if debug: print 'f', f @@ -115,55 +94,49 @@ #what are the dimensions of the tensor? dim = numpy.fromfile(f, dtype='int32', count=max(ndim,3))[:ndim] - dim_size = prod(dim) + dim_size = _prod(dim) if debug: print 'header dim', dim, dim_size rval = None if subtensor is None: - rval = read_ndarray(f, dim, magic_t) + rval = numpy.fromfile(f, dtype=magic_t, count=_prod(dim)).reshape(dim) elif isinstance(subtensor, slice): if subtensor.step not in (None, 1): raise NotImplementedError('slice with step', subtensor.step) if subtensor.start not in (None, 0): - bytes_per_row = prod(dim[1:]) * elsize + bytes_per_row = _prod(dim[1:]) * elsize raise NotImplementedError('slice with start', subtensor.start) dim[0] = min(dim[0], subtensor.stop) - rval = read_ndarray(f, dim, magic_t) + rval = numpy.fromfile(f, dtype=magic_t, count=_prod(dim)).reshape(dim) else: raise NotImplementedError('subtensor access not written yet:', subtensor) return rval def write(f, mat): + """Write a numpy.ndarray to file. + + If 'f' is a string, then it will be interpreted as a filename. This filename + will be opened in 'w+' mode, and (automatically) closed at the end of the function. + """ + def _write_int32(f, i): + i_array = numpy.asarray(i, dtype='int32') + if 0: print 'writing int32', i, i_array + i_array.tofile(f) if isinstance(f, str): - f = file(f, 'w') + f = file(f, 'w+') - _write_int32(f, _dtype_magic[str(mat.dtype)]) + try: + _write_int32(f, _dtype_magic[str(mat.dtype)]) + except KeyError: + raise TypeError('Invalid ndarray dtype for filetensor format', mat.dtype) + _write_int32(f, len(mat.shape)) shape = mat.shape if len(shape) < 3: shape = list(shape) + [1] * (3 - len(shape)) - print 'writing shape =', shape + if 0: print 'writing shape =', shape for sh in shape: _write_int32(f, sh) mat.tofile(f) -if __name__ == '__main__': - #a small test script, starts by reading sys.argv[1] - rval = read(sys.argv[1], None, debug=True) #load from filename - print 'rval', rval.shape, rval.size - - if 0: - f = file('/tmp/some_mat', 'w'); - write(f, rval) - print '' - f.close() - f = file('/tmp/some_mat', 'r'); - rval2 = read(f) #load from file handle - print 'rval2', rval2.shape, rval2.size - - assert rval.dtype == rval2.dtype - assert rval.shape == rval2.shape - assert numpy.all(rval == rval2) - print 'ok' -