view pmat.py @ 416:8849eba55520

Can now do minibatch update
author Joseph Turian <turian@iro.umontreal.ca>
date Fri, 11 Jul 2008 16:34:46 -0400
parents c2f17f231960
children
line wrap: on
line source

## Automatically adapted for numpy.numarray Jun 13, 2007 by python_numarray_to_numpy (-xsm)

# PMat.py
# Copyright (C) 2005 Pascal Vincent
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#   1. Redistributions of source code must retain the above copyright
#      notice, this list of conditions and the following disclaimer.
#
#   2. Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#
#   3. The name of the authors may not be used to endorse or promote
#      products derived from this software without specific prior written
#      permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR
#  IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
#  OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
#  NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
#  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
#  TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
#  PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
#  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
#  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#  This file is part of the PLearn library. For more information on the PLearn
#  library, go to the PLearn Web site at www.plearn.org


# Author: Pascal Vincent

#import numarray, sys, os, os.path
import numpy.numarray, sys, os, os.path
import fpconst

def array_columns( a, cols ):
    """Return the columns of the 2-D array `a` selected by `cols`.

    `cols` may be:
      - a single int column index,
      - a slice, or
      - any iterable of int column indices.

    The result is always 2-D: one output column per selected index, even
    when `cols` is a single int.

    Slices are resolved against the actual number of columns of `a`, so
    open-ended (stop=None), negative and out-of-range bounds behave like
    ordinary Python slicing.
    """
    if isinstance( cols, int ):
        indices = [ cols ]
    elif isinstance( cols, slice ):
        # slice.indices() expects the length of the sequence being sliced.
        # Passing cols.stop instead (as was done before) raises for
        # stop=None or a negative stop and over-runs the array when stop is
        # larger than the width.
        width = numpy.shape(a)[1]
        # arange always yields an integer array, even when it is empty.
        indices = numpy.arange( *cols.indices(width) )
    else:
        indices = list( cols )

    return numpy.take(a, indices, axis=1)

def load_pmat_as_array(fname):
    """Load the whole .pmat file `fname` into memory and return it as an array.

    The first 64 bytes of the file are a text header made of five
    whitespace-separated fields: structure type, length, width, data type
    ('DOUBLE' or 'FLOAT') and endianness ('LITTLE_ENDIAN' or 'BIG_ENDIAN').
    Everything after the header is the raw matrix data.

    Returns an array of shape (length, width).  If the byte order recorded
    in the header differs from sys.byteorder the array is byte-swapped in
    place before being returned.

    Raises ValueError if the data type or endianness field is not one of the
    values listed above.
    """
    # Close the file deterministically instead of relying on garbage
    # collection of the anonymous file object.
    f = open(fname,'rb')
    try:
        s = f.read()
    finally:
        f.close()

    formatstr = s[0:64]
    datastr = s[64:]
    structuretype, l, w, data_type, endianness = formatstr.split()

    if data_type=='DOUBLE':
        elemtype = 'd'
    elif data_type=='FLOAT':
        elemtype = 'f'
    else:
        raise ValueError('Invalid data type in file header: '+data_type)

    if endianness=='LITTLE_ENDIAN':
        byteorder = 'little'
    elif endianness=='BIG_ENDIAN':
        byteorder = 'big'
    else:
        raise ValueError('Invalid endianness in file header: '+endianness)

    l = int(l)
    w = int(w)
    X = numpy.numarray.fromstring(datastr,elemtype, shape=(l,w) )
    if byteorder!=sys.byteorder:
        # The file was written on a machine with the other byte order.
        X.byteswap(True)
    return X

def load_pmat_as_array_dataset(fname):
    """Load the .pmat file `fname` and wrap it in a dataset.ArrayDataSet.

    The matrix itself is read with load_pmat_as_array.  Field names are
    taken from the first token of every non-blank line of
    '<fname>.metadata/fieldnames'; when that file does not exist, the names
    'field_0', 'field_1', ... (one per column) are used instead.

    Returns dataset.ArrayDataSet(a, lookup_list.LookupList(fieldnames,
    [0, 1, ..., width-1])).
    """
    import dataset,lookup_list

    #load the pmat as array
    a=load_pmat_as_array(fname)
    width = a.shape[1]

    #load the fieldnames
    fieldnamefile = os.path.join(fname+'.metadata','fieldnames')
    if os.path.isfile(fieldnamefile):
        fieldnames = []
        f = open(fieldnamefile)
        try:
            for row in f:
                row = row.split()
                if len(row)>0:
                    fieldnames.append(row[0])
        finally:
            f.close()
    else:
        # This is a plain function, not a method: the default names must be
        # bound to the local variable (the former `self.fieldnames = ...`
        # raised NameError whenever the metadata file was missing).
        fieldnames = [ "field_"+str(i) for i in range(width) ]

    return dataset.ArrayDataSet(a,lookup_list.LookupList(fieldnames,list(range(width))))

def load_amat_as_array_dataset(fname):
    """Load the .amat file `fname` and wrap it in a dataset.ArrayDataSet.

    The array and the field names both come from readAMat.  When the file
    declares no field names, 'field_0', 'field_1', ... (one per column) are
    used instead.

    Returns dataset.ArrayDataSet(a, lookup_list.LookupList(fieldnames,
    [0, 1, ..., width-1])).
    """
    import dataset,lookup_list

    #load the amat as array
    (a,fieldnames)=readAMat(fname)
    width = a.shape[1]

    #load the fieldnames
    if not fieldnames:
        # This is a plain function, not a method: the default names must be
        # bound to the local variable (the former `self.fieldnames = ...`
        # raised NameError whenever the file had no '#:' line).
        fieldnames = [ "field_"+str(i) for i in range(width) ]

    return dataset.ArrayDataSet(a,lookup_list.LookupList(fieldnames,list(range(width))))

def save_array_dataset_as_pmat(fname,ds):
    """Write the dataset `ds` to the .pmat file `fname`.

    Hands ds.data and ds.fieldNames() over to save_array_as_pmat.
    """
    save_array_as_pmat(fname, ds.data, ds.fieldNames())

def save_array_as_pmat( fname, ar, fieldnames=[] ):
    """Write the 2-D array `ar` to `fname` in .pmat format.

    fname      -- path of the file to create (overwritten if it exists).
    ar         -- 2-D array whose dtype char is 'd' (written as DOUBLE) or
                  'f' (written as FLOAT).
    fieldnames -- optional sequence of column names, one per column.  When
                  non-empty, the names are written one per line, each
                  followed by a tab and '0', to
                  '<fname>.metadata/fieldnames' (the directory is created
                  if needed).  The default list is never modified.

    The file consists of a 64-byte header
    'MATRIX <length> <width> <type> <endianness>' padded with spaces and
    terminated by a newline, followed by the raw bytes of `ar` in the
    machine's native byte order.

    Raises TypeError for an unsupported dtype or an unknown sys.byteorder,
    and AssertionError if len(fieldnames) differs from the array width.
    All of these checks run before `fname` is opened, so a rejected call
    does not truncate an existing file.
    """
    length, width = ar.shape

    if fieldnames:
        assert len(fieldnames) == width

    header = 'MATRIX ' + str(length) + ' ' + str(width) + ' '
    if ar.dtype.char=='d':
        header += 'DOUBLE '
    elif ar.dtype.char=='f':
        header += 'FLOAT '
    else:
        raise TypeError('Unsupported typecode: %s' % ar.dtype.char)

    if sys.byteorder=='little':
        header += 'LITTLE_ENDIAN '
    elif sys.byteorder=='big':
        header += 'BIG_ENDIAN '
    else:
        raise TypeError('Unsupported sys.byteorder: '+repr(sys.byteorder))

    # Pad to 63 characters + newline so the data starts at byte 64.
    header += ' '*(63-len(header))+'\n'

    # Only touch the filesystem once the input is known to be valid.
    s = open(fname,'wb')
    try:
        if fieldnames:
            metadatadir = fname+'.metadata'
            if not os.path.isdir(metadatadir):
                os.mkdir(metadatadir)
            fieldnamefile = os.path.join(metadatadir,'fieldnames')
            f = open(fieldnamefile,'wb')
            try:
                for name in fieldnames:
                    f.write(name+'\t0\n')
            finally:
                f.close()

        s.write( header )
        s.write( ar.tostring() )
    finally:
        s.close()


#######  Iterators  ###########################################################

class VMatIt:
    def __init__(self, vmat):                
        self.vmat = vmat
        self.cur_row = 0

    def __iter__(self):
        return self

    def next(self):
        if self.cur_row==self.vmat.length:
            raise StopIteration
        row = self.vmat.getRow(self.cur_row)
        self.cur_row += 1
        return row    

class ColumnIt:
    def __init__(self, vmat, col):                
        self.vmat = vmat
        self.col  = col
        self.cur_row = 0

    def __iter__(self):
        return self

    def next(self):
        if self.cur_row==self.vmat.length:
            raise StopIteration
        val = self.vmat[self.cur_row, self.col]
        self.cur_row += 1
        return val

#######  VMat classes  ########################################################

class VMat:
    """Base class providing iteration and indexing for matrix-like objects.

    Subclasses are expected to provide the attributes `length`, `width` and
    `fieldnames` and the methods getRow(i) and getRows(i, l); everything in
    this class is implemented on top of those.
    """
    def __iter__(self):
        return VMatIt(self)

    def __getitem__( self, key ):
        """Index the matrix.

        Supported keys:
          m[i]        -- one row (negative i counts from the end);
          m[i:j]      -- rows i..j-1 via getRows(); bounds follow the usual
                         Python slice rules (None, negative and out-of-range
                         values are resolved against self.length, and an
                         empty range requests zero rows);
          m[r, c]     -- rows selected by `r` (int or slice), then columns
                         selected by `c` (int, slice or iterable of ints);
          m['name']   -- the column whose field name is 'name'.

        Raises IndexError for a row slice with an explicit step, and
        ValueError (after printing the known field names on stderr) for an
        unknown field name.
        """
        if isinstance( key, slice ):
            if key.step is not None:
                raise IndexError('Extended slice with step not currently supported')

            # slice.indices() resolves None, negative and out-of-range
            # bounds against the number of rows.
            start, stop, _ = key.indices(self.length)
            return self.getRows(start, max(stop-start, 0))

        elif isinstance( key, tuple ):
            # Basically returns a SubVMatrix
            assert len(key) == 2
            rows = self.__getitem__( key[0] )

            if len(rows.shape) == 1:
                # A single row was selected: index directly into it.
                return rows[ key[1] ]

            cols = key[1]
            if isinstance(cols, slice):
                # Expand the slice into explicit column indices, resolved
                # against the real width, so that negative starts/stops
                # select the columns Python slicing would select.
                cols = list( range( *cols.indices(self.width) ) )

            return array_columns(rows, cols)

        elif isinstance( key, str ):
            # The key is considered to be a fieldname and a column is
            # returned.
            try:
                return array_columns( self.getRows(0,self.length),
                                      self.fieldnames.index(key)  )
            except ValueError:
                sys.stderr.write("Key is '%s' while fieldnames are:\n" % key)
                sys.stderr.write("%s\n" % (self.fieldnames,))
                raise

        else:
            if key<0: key+=self.length
            return self.getRow(key)

    def getFieldIndex(self, fieldname):
        """Return the column index of `fieldname`.

        Raises ValueError, listing the known field names, if there is no
        such field.
        """
        try:
            return self.fieldnames.index(fieldname)
        except ValueError:
            raise ValueError( "VMat has no field named %s. Field names: %s"
                              %(fieldname, ','.join(self.fieldnames)) )

class PMat( VMat ):
    """A matrix stored in a .pmat file and accessed row by row.

    File layout, as read and written by this class: a 64-byte text header
    'MATRIX <length> <width> <DOUBLE|FLOAT> <LITTLE_ENDIAN|BIG_ENDIAN>'
    padded with spaces and terminated by a newline, followed by the rows
    stored one after the other (row i starts at byte 64 + i*rowsize).
    Field names are kept in '<fname>.metadata/fieldnames'.
    """

    def __init__(self, fname, openmode='r', fieldnames=[], elemtype='d',
                 inputsize=-1, targetsize=-1, weightsize=-1, array = None):
        """Open or create the .pmat file `fname`.

        openmode:
          'r' -- open an existing file for reading;
          'w' -- create (or truncate) the file with one column per entry of
                 `fieldnames` and elements of type `elemtype` ('d' or 'f');
          'a' -- open an existing file for reading and writing.
        `fieldnames` and `elemtype` are only used in 'w' mode; in the other
        modes they are read from disk.  inputsize, targetsize and weightsize
        are simply stored as attributes.  If `array` is given, each of its
        rows (or each element, for a 1-D array) is appended after opening.

        Raises ValueError for an unsupported openmode or for an `array`
        that is neither 1-D nor 2-D.
        """
        # Decide how to turn the elements of `array` into rows before any
        # file is opened, so an unusable array does not truncate a file.
        row_format = None
        if array is not None:
            ndim = len(array.shape)
            if ndim == 1:
                row_format = lambda r: [ r ]
            elif ndim == 2:
                row_format = lambda r: r
            else:
                raise ValueError("array must be 1-D or 2-D, got %d dimensions" % ndim)

        self.fname = fname
        self.inputsize = inputsize
        self.targetsize = targetsize
        self.weightsize = weightsize
        if openmode=='r':
            self.f = open(fname,'rb')
            self.read_and_parse_header()
            self.load_fieldnames()

        elif openmode=='w':
            self.f = open(fname,'w+b')
            # Copy: never alias the caller's list or the shared default.
            self.fieldnames = list(fieldnames)
            self.save_fieldnames()
            self.length = 0
            self.width = len(fieldnames)
            self.elemtype = elemtype
            self.swap_bytes = False
            self.write_header()

        elif openmode=='a':
            self.f = open(fname,'r+b')
            self.read_and_parse_header()
            self.load_fieldnames()

        else:
            raise ValueError("Currently only supported openmodes are 'r', 'w' and 'a': "+repr(openmode)+" is not supported")

        if array is not None:
            for row in array:
                self.appendRow( row_format(row) )

    def __del__(self):
        self.close()

    def write_header(self):
        """(Re)write the 64-byte header at the start of the file.

        Also sets self.elemsize and self.rowsize from self.elemtype and
        self.width.  The header always records the machine's native byte
        order.  Raises TypeError for an unsupported elemtype or an unknown
        sys.byteorder.
        """
        header = 'MATRIX ' + str(self.length) + ' ' + str(self.width) + ' '

        if self.elemtype=='d':
            header += 'DOUBLE '
            self.elemsize = 8
        elif self.elemtype=='f':
            header += 'FLOAT '
            self.elemsize = 4
        else:
            # `elemtype` alone is not defined in this scope.
            raise TypeError('Unsupported elemtype: '+repr(self.elemtype))
        self.rowsize = self.elemsize*self.width

        if sys.byteorder=='little':
            header += 'LITTLE_ENDIAN '
        elif sys.byteorder=='big':
            header += 'BIG_ENDIAN '
        else:
            raise TypeError('Unsupported sys.byteorder: '+repr(sys.byteorder))

        # Pad to 63 characters + newline so the data starts at byte 64.
        header += ' '*(63-len(header))+'\n'

        self.f.seek(0)
        self.f.write(header)

    def read_and_parse_header(self):
        """Read the 64-byte header from the current file position.

        Sets self.length, self.width, self.elemtype, self.elemsize,
        self.rowsize and self.swap_bytes (True when the file's byte order
        differs from sys.byteorder).  Raises ValueError if the header does
        not start with MATRIX or names an unknown endianness or data type.
        """
        header = self.f.read(64)
        mat_type, l, w, data_type, endianness = header.split()
        if mat_type!='MATRIX':
            raise ValueError('Invalid file header (should start with MATRIX)')
        self.length = int(l)
        self.width = int(w)
        if endianness=='LITTLE_ENDIAN':
            byteorder = 'little'
        elif endianness=='BIG_ENDIAN':
            byteorder = 'big'
        else:
            raise ValueError('Invalid endianness in file header: '+endianness)
        self.swap_bytes = (byteorder!=sys.byteorder)

        if data_type=='DOUBLE':
            self.elemtype = 'd'
            self.elemsize = 8
        elif data_type=='FLOAT':
            self.elemtype = 'f'
            self.elemsize = 4
        else:
            raise ValueError('Invalid data type in file header: '+data_type)
        self.rowsize = self.elemsize*self.width

    def load_fieldnames(self):
        """Set self.fieldnames from '<fname>.metadata/fieldnames'.

        The first token of every non-blank line is taken as a name.  If the
        file does not exist, 'field_0', 'field_1', ... are used.
        """
        fieldnamefile = os.path.join(self.fname+'.metadata','fieldnames')
        if os.path.isfile(fieldnamefile):
            self.fieldnames = []
            f = open(fieldnamefile)
            try:
                for row in f:
                    row = row.split()
                    if len(row)>0:
                        self.fieldnames.append(row[0])
            finally:
                f.close()
        else:
            self.fieldnames = [ "field_"+str(i) for i in range(self.width) ]

    def save_fieldnames(self):
        """Write self.fieldnames to '<fname>.metadata/fieldnames'.

        The metadata directory is created if needed; each name is written
        on its own line followed by a tab and '0'.
        """
        metadatadir = self.fname+'.metadata'
        if not os.path.isdir(metadatadir):
            os.mkdir(metadatadir)
        fieldnamefile = os.path.join(metadatadir,'fieldnames')
        f = open(fieldnamefile,'wb')
        try:
            for name in self.fieldnames:
                f.write(name+'\t0\n')
        finally:
            f.close()

    def _encode_row(self, row):
        """Return `row` converted to an array of self.elemtype, in the byte
        order used by the file."""
        if self.swap_bytes: # must make a copy and swap bytes
            # NOTE(review): call kept exactly as in the original code;
            # confirm that numpy.numarray exposes `numarray` as a callable.
            ar = numpy.numarray.numarray(row,type=self.elemtype)
            ar.byteswap(True)
        else: # asarray makes a copy if not already a numarray of the right type
            ar = numpy.numarray.asarray(row,type=self.elemtype)
        return ar

    def getRow(self,i):
        """Return row `i` as a 1-D array of self.width elements.

        Raises IndexError unless 0 <= i < self.length.
        """
        if i<0 or i>=self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(self.rowsize)
        ar = numpy.numarray.fromstring(data, self.elemtype, (self.width,))
        if self.swap_bytes:
            ar.byteswap(True)
        return ar

    def getRows(self,i,l):
        """Return the `l` rows starting at row `i` as an (l, width) array.

        Raises IndexError if i or l is negative or the range runs past the
        end of the matrix.
        """
        if i<0 or l<0 or i+l>self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(l*self.rowsize)
        ar = numpy.numarray.fromstring(data, self.elemtype, (l,self.width))
        if self.swap_bytes:
            ar.byteswap(True)
        return ar

    def checkzerorow(self,i):
        """Return True if every element read for row `i` equals 0.

        Only as many elements as were actually read from the file are
        examined.
        """
        # NOTE(review): unlike getRow, i == self.length is accepted here
        # (bound kept as in the original); confirm whether this is intended.
        if i<0 or i>self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(self.rowsize)
        # Floor division keeps the element count an integer.
        ar = numpy.numarray.fromstring(data, self.elemtype, (len(data)//self.elemsize,))
        if self.swap_bytes:
            ar.byteswap(True)
        for elem in ar:
            if elem!=0:
                return False
        return True

    def putRow(self,i,row):
        """Overwrite row `i` with `row`.

        Raises IndexError unless 0 <= i < self.length, and TypeError if
        len(row) differs from self.width.
        """
        if i<0 or i>=self.length:
            raise IndexError('PMat index out of range')
        if len(row)!=self.width:
            raise TypeError('length of row ('+str(len(row))+ ') differs from matrix width ('+str(self.width)+')')
        ar = self._encode_row(row)
        self.f.seek(64+i*self.rowsize)
        self.f.write(ar.tostring())

    def appendRow(self,row):
        """Append `row` after the last row and update the length stored in
        the header.

        Raises TypeError if len(row) differs from self.width.
        """
        if len(row)!=self.width:
            raise TypeError('length of row ('+str(len(row))+ ') differs from matrix width ('+str(self.width)+')')
        ar = self._encode_row(row)

        self.f.seek(64+self.length*self.rowsize)
        self.f.write(ar.tostring())
        self.length += 1
        self.write_header() # update length in header

    def flush(self):
        """Flush the underlying file object."""
        self.f.flush()

    def close(self):
        """Close the underlying file, if one was opened."""
        # self.f is missing when open() failed inside __init__; __del__
        # still ends up here in that case.
        if hasattr(self, 'f'):
            self.f.close()

    def append(self,row):
        """Alias for appendRow."""
        self.appendRow(row)

    def __setitem__(self, i, row):
        """Overwrite row `i` (negative `i` counts from the end)."""
        l = self.length
        if i<0: i+=l
        self.putRow(i,row)

    def __len__(self):
        return self.length



#copied from PLEARNDIR:python_modules/plearn/vmat/readAMat.py
def safefloat(str):
    """Return the float value of the given string.

    The string 'nan' (in any letter case) is mapped to fpconst.NaN rather
    than being handed to 'float(str)', so that missing values convert
    properly on every platform.  Any other string goes through float().
    """
    is_missing = (str.lower() == 'nan')
    return fpconst.NaN if is_missing else float(str)

#copied from PLEARNDIR:python_modules/plearn/vmat/readAMat.py
def readAMat(amatname):
    """Read a PLearn .amat file and return it as a numarray Array.

    Return a tuple, with as the first argument the array itself, and as
    the second argument the fieldnames (list of strings).

    Line handling:
      '#size: L W'  -- must hold exactly two fields (ValueError otherwise);
                       the values themselves are not used;
      '#sizes: ...' -- ignored (input/target/weight/extra sizes);
      '#: a b c'    -- field names;
      other '#...'  -- comment, ignored;
      anything else -- one row of whitespace-separated numbers, converted
                       with safefloat; blank lines are skipped.
    """
    ### NOTE: this version is much faster than first creating the array and
    ### updating each row as it is read...  Bizarrely enough
    a = []
    fieldnames = []
    f = open(amatname)
    # try/finally: close the file even when a value fails to parse.
    try:
        for line in f:
            if line.startswith("#size:"):
                # Unpacking doubles as a sanity check of the size line.
                (length,width) = line[6:].strip().split()
            elif line.startswith("#sizes:"):  # ignore input/target/weight/extra sizes
                continue
            elif line.startswith("#:"):
                fieldnames = line[2:].strip().split()
            elif not line.startswith('#'):
                # Add all non-comment lines.
                row = [ safefloat(x) for x in line.strip().split() ]
                if row:
                    a.append(row)
    finally:
        f.close()

    return numpy.numarray.array(a), fieldnames

            
if __name__ == '__main__':
    pmat = PMat( 'tmp.pmat', 'w', fieldnames=['F1', 'F2'] )
    pmat.append( [1, 2] )
    pmat.append( [3, 4] )
    pmat.close()

    pmat = PMat( 'tmp.pmat', 'r' )
    ar=load_pmat_as_array('tmp.pmat')
    ds=load_pmat_as_array_dataset('tmp.pmat')
    
    print "PMat",pmat
    print "PMat",pmat[:]
    print "array",ar
    print "ArrayDataSet",ds
    for i in ds:
        print i
    save_array_dataset_as_pmat("tmp2.pmat",ds)
    ds2=load_pmat_as_array_dataset('tmp2.pmat')
    for i in ds2:
        print i
    # print "+++ tmp.pmat contains: "
    # os.system( 'plearn vmat cat tmp.pmat' )
    import shutil
    for fname in ["tmp.pmat", "tmp2.pmat"]:
        os.remove( fname )
        if os.path.exists( fname+'.metadata' ):
            shutil.rmtree( fname+'.metadata' )