# HG changeset patch # User James Bergstra # Date 1212513285 14400 # Node ID 82ba488b2c241089408c2b037fd4ae17cf2936b8 # Parent c702abb7f87557ae97a5576f46d6850623a57d72 polished filetensor a little diff -r c702abb7f875 -r 82ba488b2c24 filetensor.py --- a/filetensor.py Mon Jun 02 17:09:58 2008 -0400 +++ b/filetensor.py Tue Jun 03 13:14:45 2008 -0400 @@ -16,11 +16,14 @@ For rank >= 3, the number of dimensions matches the rank exactly. + +@todo: add complex type support + """ import sys import numpy -def prod(lst): +def _prod(lst): p = 1 for l in lst: p *= l @@ -28,7 +31,7 @@ _magic_dtype = { 0x1E3D4C51 : ('float32', 4), - 0x1E3D4C52 : ('packed matrix', 0), #what is a packed matrix? + #0x1E3D4C52 : ('packed matrix', 0), #what is a packed matrix? 0x1E3D4C53 : ('float64', 8), 0x1E3D4C54 : ('int32', 4), 0x1E3D4C55 : ('uint8', 1), @@ -36,41 +39,13 @@ } _dtype_magic = { 'float32': 0x1E3D4C51, - 'packed matrix': 0x1E3D4C52, + #'packed matrix': 0x1E3D4C52, 'float64': 0x1E3D4C53, 'int32': 0x1E3D4C54, 'uint8': 0x1E3D4C55, 'int16': 0x1E3D4C56 } -def _unused(): - f.seek(0,2) #seek to end - f_len = f.tell() - f.seek(f_data_start,0) #seek back to where we were - - if debug: print 'length:', f_len - - - f_data_bytes = (f_len - f_data_start) - - if debug: print 'data bytes according to header: ', dim_size * elsize - if debug: print 'data bytes according to file : ', f_data_bytes - - if debug: print 'reading data...' - sys.stdout.flush() - -def _write_int32(f, i): - i_array = numpy.asarray(i, dtype='int32') - if 0: print 'writing int32', i, i_array - i_array.tofile(f) -def _read_int32(f): - s = f.read(4) - s_array = numpy.fromstring(s, dtype='int32') - return s_array.item() - -def read_ndarray(f, dim, dtype): - return numpy.fromfile(f, dtype=dtype, count=prod(dim)).reshape(dim) - # # TODO: implement item selection: # e.g. load('some mat', subtensor=(:6, 2:5)) @@ -94,6 +69,10 @@ particular type of subtensor is supported. """ + def _read_int32(f): + s = f.read(4) + s_array = numpy.fromstring(s, dtype='int32') + return s_array.item() if isinstance(f, str): if debug: print 'f', f @@ -115,55 +94,49 @@ #what are the dimensions of the tensor? dim = numpy.fromfile(f, dtype='int32', count=max(ndim,3))[:ndim] - dim_size = prod(dim) + dim_size = _prod(dim) if debug: print 'header dim', dim, dim_size rval = None if subtensor is None: - rval = read_ndarray(f, dim, magic_t) + rval = numpy.fromfile(f, dtype=magic_t, count=_prod(dim)).reshape(dim) elif isinstance(subtensor, slice): if subtensor.step not in (None, 1): raise NotImplementedError('slice with step', subtensor.step) if subtensor.start not in (None, 0): - bytes_per_row = prod(dim[1:]) * elsize + bytes_per_row = _prod(dim[1:]) * elsize raise NotImplementedError('slice with start', subtensor.start) dim[0] = min(dim[0], subtensor.stop) - rval = read_ndarray(f, dim, magic_t) + rval = numpy.fromfile(f, dtype=magic_t, count=_prod(dim)).reshape(dim) else: raise NotImplementedError('subtensor access not written yet:', subtensor) return rval def write(f, mat): + """Write a numpy.ndarray to file. + + If 'f' is a string, then it will be interpreted as a filename. This filename + will be opened in 'w+' mode, and (automatically) closed at the end of the function. + """ + def _write_int32(f, i): + i_array = numpy.asarray(i, dtype='int32') + if 0: print 'writing int32', i, i_array + i_array.tofile(f) if isinstance(f, str): - f = file(f, 'w') + f = file(f, 'w+') - _write_int32(f, _dtype_magic[str(mat.dtype)]) + try: + _write_int32(f, _dtype_magic[str(mat.dtype)]) + except KeyError: + raise TypeError('Invalid ndarray dtype for filetensor format', mat.dtype) + _write_int32(f, len(mat.shape)) shape = mat.shape if len(shape) < 3: shape = list(shape) + [1] * (3 - len(shape)) - print 'writing shape =', shape + if 0: print 'writing shape =', shape for sh in shape: _write_int32(f, sh) mat.tofile(f) -if __name__ == '__main__': - #a small test script, starts by reading sys.argv[1] - rval = read(sys.argv[1], None, debug=True) #load from filename - print 'rval', rval.shape, rval.size - - if 0: - f = file('/tmp/some_mat', 'w'); - write(f, rval) - print '' - f.close() - f = file('/tmp/some_mat', 'r'); - rval2 = read(f) #load from file handle - print 'rval2', rval2.shape, rval2.size - - assert rval.dtype == rval2.dtype - assert rval.shape == rval2.shape - assert numpy.all(rval == rval2) - print 'ok' - diff -r c702abb7f875 -r 82ba488b2c24 test_filetensor.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test_filetensor.py Tue Jun 03 13:14:45 2008 -0400 @@ -0,0 +1,116 @@ +from filetensor import * +import filetensor + +import unittest +import os + +class T(unittest.TestCase): + fname = '/tmp/some_mat' + + def setUp(self): + #TODO: test that /tmp/some_mat does not exist + try: + os.stat(self.fname) + except OSError: + return #assume file was not found + raise Exception('autotest file "%s" exists!' % self.fname) + + def tearDown(self): + os.remove(self.fname) + + def test_file(self): + gen = numpy.random.rand(1) + f = file(self.fname, 'w'); + write(f, gen) + f.flush() + f = file(self.fname, 'r'); + mat = read(f, None, debug=False) #load from filename + self.failUnless(gen.shape == mat.shape) + self.failUnless(numpy.all(gen == mat)) + + def test_filename(self): + gen = numpy.random.rand(1) + write(self.fname, gen) + mat = read(self.fname, None, debug=False) #load from filename + self.failUnless(gen.shape == mat.shape) + self.failUnless(numpy.all(gen == mat)) + + def testNd(self): + """shape and values are stored correctly for tensors of rank 0 to 5""" + whole_shape = [5, 6, 7, 8, 9] + for i in xrange(5): + gen = numpy.asarray(numpy.random.rand(*whole_shape[:i])) + f = file(self.fname, 'w'); + write(f, gen) + f.flush() + f = file(self.fname, 'r'); + mat = read(f, None, debug=False) #load from filename + self.failUnless(gen.shape == mat.shape) + self.failUnless(numpy.all(gen == mat)) + + def test_dtypes(self): + """shape and values are stored correctly for all dtypes """ + for dtype in filetensor._dtype_magic: + gen = numpy.asarray( + numpy.random.rand(4, 5, 2, 1) * 100, + dtype=dtype) + f = file(self.fname, 'w'); + write(f, gen) + f.flush() + f = file(self.fname, 'r'); + mat = read(f, None, debug=False) #load from filename + self.failUnless(gen.dtype == mat.dtype) + self.failUnless(gen.shape == mat.shape) + self.failUnless(numpy.all(gen == mat)) + + def test_dtype_invalid(self): + gen = numpy.zeros((3,4), dtype='uint16') #an unsupported dtype + f = file(self.fname, 'w') + passed = False + try: + write(f, gen) + except TypeError, e: + if e[0].startswith('Invalid ndarray dtype'): + passed = True + f.close() + self.failUnless(passed) + + +if __name__ == '__main__': + unittest.main() + + #a small test script, starts by reading sys.argv[1] + #print 'rval', rval.shape, rval.size + + if 0: + write(f, rval) + print '' + f.close() + f = file('/tmp/some_mat', 'r'); + rval2 = read(f) #load from file handle + print 'rval2', rval2.shape, rval2.size + + assert rval.dtype == rval2.dtype + assert rval.shape == rval2.shape + assert numpy.all(rval == rval2) + print 'ok' + + def _unused(): + f.seek(0,2) #seek to end + f_len = f.tell() + f.seek(f_data_start,0) #seek back to where we were + + if debug: print 'length:', f_len + + + f_data_bytes = (f_len - f_data_start) + + if debug: print 'data bytes according to header: ', dim_size * elsize + if debug: print 'data bytes according to file : ', f_data_bytes + + if debug: print 'reading data...' + sys.stdout.flush() + + def read_ndarray(f, dim, dtype): + return numpy.fromfile(f, dtype=dtype, count=_prod(dim)).reshape(dim) +