changeset 119:7ffecde9dadc

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Wed, 07 May 2008 15:08:18 -0400
parents 9330d941fa1f (diff) d0a1bd0378c6 (current diff)
children 5fa46297191b
files
diffstat 2 files changed, 479 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pmat.py	Wed May 07 15:08:18 2008 -0400
@@ -0,0 +1,470 @@
+## Automatically adapted for numpy.numarray Jun 13, 2007 by python_numarray_to_numpy (-xsm)
+
+# PMat.py
+# Copyright (C) 2005 Pascal Vincent
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#
+#   1. Redistributions of source code must retain the above copyright
+#      notice, this list of conditions and the following disclaimer.
+#
+#   2. Redistributions in binary form must reproduce the above copyright
+#      notice, this list of conditions and the following disclaimer in the
+#      documentation and/or other materials provided with the distribution.
+#
+#   3. The name of the authors may not be used to endorse or promote
+#      products derived from this software without specific prior written
+#      permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR
+#  IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
+#  OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
+#  NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+#  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+#  TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+#  PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+#  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+#  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#  This file is part of the PLearn library. For more information on the PLearn
+#  library, go to the PLearn Web site at www.plearn.org
+
+
+# Author: Pascal Vincent
+
+#import numarray, sys, os, os.path
+import numpy.numarray, sys, os, os.path
+
+def array_columns( a, cols ):
+    indices = None
+    if isinstance( cols, int ):
+        indices = [ cols ]
+    elif isinstance( cols, slice ):
+        #print cols
+        indices = range( *cols.indices(cols.stop) )
+    else:
+        indices = list( cols )            
+
+    return numpy.numarray.take(a, indices, axis=1)
+
+def load_pmat_as_array(fname):
+    s = file(fname,'rb').read()
+    formatstr = s[0:64]
+    datastr = s[64:]
+    structuretype, l, w, data_type, endianness = formatstr.split()
+
+    if data_type=='DOUBLE':
+        elemtype = 'd'
+    elif data_type=='FLOAT':
+        elemtype = 'f'
+    else:
+        raise ValueError('Invalid data type in file header: '+data_type)
+
+    if endianness=='LITTLE_ENDIAN':
+        byteorder = 'little'
+    elif endianness=='BIG_ENDIAN':
+        byteorder = 'big'
+    else:
+        raise ValueError('Invalid endianness in file header: '+endianness)
+
+    l = int(l)
+    w = int(w)
+    X = numpy.numarray.fromstring(datastr,elemtype, shape=(l,w) )
+    if byteorder!=sys.byteorder:
+        X.byteswap(True)
+    return X
+
+def load_pmat_as_array_dataset(fname):
+    import dataset,lookup_list
+    
+    #load the pmat as array
+    a=load_pmat_as_array(fname)
+    
+    #load the fieldnames
+    fieldnames = []
+    fieldnamefile = os.path.join(fname+'.metadata','fieldnames')
+    if os.path.isfile(fieldnamefile):
+        f = open(fieldnamefile)
+        for row in f:
+            row = row.split()
+            if len(row)>0:
+                fieldnames.append(row[0])
+        f.close()
+    else:
+        self.fieldnames = [ "field_"+str(i) for i in range(a.shape[1]) ]
+
+    return dataset.ArrayDataSet(a,lookup_list.LookupList(fieldnames,[x for x in range(a.shape[1])]))
+
+def save_array_dataset_as_pmat(fname,ds):
+    ar=ds.data
+    save_array_as_pmat(fname,ar,ds.fieldNames())
+
+def save_array_as_pmat( fname, ar, fieldnames=[] ):
+    s = file(fname,'wb')
+    
+    length, width = ar.shape
+    if fieldnames:
+        assert len(fieldnames) == width
+        metadatadir = fname+'.metadata'
+        if not os.path.isdir(metadatadir):
+            os.mkdir(metadatadir)
+        fieldnamefile = os.path.join(metadatadir,'fieldnames')
+        f = open(fieldnamefile,'wb')
+        for name in fieldnames:
+            f.write(name+'\t0\n')
+        f.close()
+    
+    header = 'MATRIX ' + str(length) + ' ' + str(width) + ' '
+    if ar.dtype.char=='d':
+        header += 'DOUBLE '
+        elemsize = 8
+
+    elif ar.dtype.char=='f':
+        header += 'FLOAT '
+        elemsize = 4
+
+    else:
+        raise TypeError('Unsupported typecode: %s' % ar.dtype.char)
+
+    rowsize = elemsize*width
+
+    if sys.byteorder=='little':
+        header += 'LITTLE_ENDIAN '
+    elif sys.byteorder=='big':
+        header += 'BIG_ENDIAN '
+    else:
+        raise TypeError('Unsupported sys.byteorder: '+repr(sys.byteorder))
+
+    header += ' '*(63-len(header))+'\n'
+    s.write( header )
+    s.write( ar.tostring() )
+    s.close()    
+
+
+#######  Iterators  ###########################################################
+
+class VMatIt:
+    def __init__(self, vmat):                
+        self.vmat = vmat
+        self.cur_row = 0
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        if self.cur_row==self.vmat.length:
+            raise StopIteration
+        row = self.vmat.getRow(self.cur_row)
+        self.cur_row += 1
+        return row    
+
+class ColumnIt:
+    def __init__(self, vmat, col):                
+        self.vmat = vmat
+        self.col  = col
+        self.cur_row = 0
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        if self.cur_row==self.vmat.length:
+            raise StopIteration
+        val = self.vmat[self.cur_row, self.col]
+        self.cur_row += 1
+        return val
+
+#######  VMat classes  ########################################################
+
+class VMat:
+    def __iter__(self):
+        return VMatIt(self)
+    
+    def __getitem__( self, key ):
+        if isinstance( key, slice ):
+            start, stop, step = key.start, key.stop, key.step
+            if step!=None:
+                raise IndexError('Extended slice with step not currently supported')
+
+            if start is None:
+                start = 0
+
+            l = self.length
+            if stop is None or stop > l:
+                stop = l
+
+            return self.getRows(start,stop-start)
+        
+        elif isinstance( key, tuple ):
+            # Basically returns a SubVMatrix
+            assert len(key) == 2
+            rows = self.__getitem__( key[0] )
+
+            shape = rows.shape                       
+            if len(shape) == 1:
+                return rows[ key[1] ]
+
+            cols = key[1]
+            if isinstance(cols, slice):
+                start, stop, step = cols.start, cols.stop, cols.step
+                if start is None:
+                    start = 0
+
+                if stop is None:
+                    stop = self.width
+                elif stop < 0:
+                    stop = self.width+stop
+
+                cols = slice(start, stop, step)
+
+            return array_columns(rows, cols)
+
+        elif isinstance( key, str ):
+            # The key is considered to be a fieldname and a column is
+            # returned.
+            try:
+                return array_columns( self.getRows(0,self.length),
+                                      self.fieldnames.index(key)  )
+            except ValueError:
+                print >>sys.stderr, "Key is '%s' while fieldnames are:" % key
+                print >>sys.stderr, self.fieldnames
+                raise
+                
+        else:
+            if key<0: key+=self.length
+            return self.getRow(key)
+
+    def getFieldIndex(self, fieldname):
+        try:
+            return self.fieldnames.index(fieldname)
+        except ValueError:
+            raise ValueError( "VMat has no field named %s. Field names: %s"
+                              %(fieldname, ','.join(self.fieldnames)) )
+
+class PMat( VMat ):
+
+    def __init__(self, fname, openmode='r', fieldnames=[], elemtype='d',
+                 inputsize=-1, targetsize=-1, weightsize=-1, array = None):
+        self.fname = fname
+        self.inputsize = inputsize
+        self.targetsize = targetsize
+        self.weightsize = weightsize
+        if openmode=='r':
+            self.f = open(fname,'rb')
+            self.read_and_parse_header()
+            self.load_fieldnames()            
+                
+        elif openmode=='w':
+            self.f = open(fname,'w+b')
+            self.fieldnames = fieldnames
+            self.save_fieldnames()
+            self.length = 0
+            self.width = len(fieldnames)
+            self.elemtype = elemtype
+            self.swap_bytes = False
+            self.write_header()            
+            
+        elif openmode=='a':
+            self.f = open(fname,'r+b')
+            self.read_and_parse_header()
+            self.load_fieldnames()
+
+        else:
+            raise ValueError("Currently only supported openmodes are 'r', 'w' and 'a': "+repr(openmode)+" is not supported")
+
+        if array is not None:
+            shape  = array.shape
+            if len(shape) == 1:
+                row_format = lambda r: [ r ]
+            elif len(shape) == 2:
+                row_format = lambda r: r
+
+            for row in array:
+                self.appendRow( row_format(row) )
+
+    def __del__(self):
+        self.close()
+
+    def write_header(self):
+        header = 'MATRIX ' + str(self.length) + ' ' + str(self.width) + ' '
+
+        if self.elemtype=='d':
+            header += 'DOUBLE '
+            self.elemsize = 8
+        elif self.elemtype=='f':
+            header += 'FLOAT '
+            self.elemsize = 4
+        else:
+            raise TypeError('Unsupported elemtype: '+repr(elemtype))
+        self.rowsize = self.elemsize*self.width
+
+        if sys.byteorder=='little':
+            header += 'LITTLE_ENDIAN '
+        elif sys.byteorder=='big':
+            header += 'BIG_ENDIAN '
+        else:
+            raise TypeError('Unsupported sys.byteorder: '+repr(sys.byteorder))
+        
+        header += ' '*(63-len(header))+'\n'
+
+        self.f.seek(0)
+        self.f.write(header)
+        
+    def read_and_parse_header(self):        
+        header = self.f.read(64)        
+        mat_type, l, w, data_type, endianness = header.split()
+        if mat_type!='MATRIX':
+            raise ValueError('Invalid file header (should start with MATRIX)')
+        self.length = int(l)
+        self.width = int(w)
+        if endianness=='LITTLE_ENDIAN':
+            byteorder = 'little'
+        elif endianness=='BIG_ENDIAN':
+            byteorder = 'big'
+        else:
+            raise ValueError('Invalid endianness in file header: '+endianness)
+        self.swap_bytes = (byteorder!=sys.byteorder)
+
+        if data_type=='DOUBLE':
+            self.elemtype = 'd'
+            self.elemsize = 8
+        elif data_type=='FLOAT':
+            self.elemtype = 'f'
+            self.elemsize = 4
+        else:
+            raise ValueError('Invalid data type in file header: '+data_type)
+        self.rowsize = self.elemsize*self.width
+
+    def load_fieldnames(self):
+        self.fieldnames = []
+        fieldnamefile = os.path.join(self.fname+'.metadata','fieldnames')
+        if os.path.isfile(fieldnamefile):
+            f = open(fieldnamefile)
+            for row in f:
+                row = row.split()
+                if len(row)>0:
+                    self.fieldnames.append(row[0])
+            f.close()
+        else:
+            self.fieldnames = [ "field_"+str(i) for i in range(self.width) ]
+            
+    def save_fieldnames(self):
+        metadatadir = self.fname+'.metadata'
+        if not os.path.isdir(metadatadir):
+            os.mkdir(metadatadir)
+        fieldnamefile = os.path.join(metadatadir,'fieldnames')
+        f = open(fieldnamefile,'wb')
+        for name in self.fieldnames:
+            f.write(name+'\t0\n')
+        f.close()
+
+    def getRow(self,i):
+        if i<0 or i>=self.length:
+            raise IndexError('PMat index out of range')
+        self.f.seek(64+i*self.rowsize)
+        data = self.f.read(self.rowsize)
+        ar = numpy.numarray.fromstring(data, self.elemtype, (self.width,))
+        if self.swap_bytes:
+            ar.byteswap(True)
+        return ar
+
+    def getRows(self,i,l):
+        if i<0 or l<0 or i+l>self.length:
+            raise IndexError('PMat index out of range')
+        self.f.seek(64+i*self.rowsize)
+        data = self.f.read(l*self.rowsize)
+        ar = numpy.numarray.fromstring(data, self.elemtype, (l,self.width))
+        if self.swap_bytes:
+            ar.byteswap(True)
+        return ar
+
+    def checkzerorow(self,i):
+        if i<0 or i>self.length:
+            raise IndexError('PMat index out of range')
+        self.f.seek(64+i*self.rowsize)
+        data = self.f.read(self.rowsize)
+        ar = numpy.numarray.fromstring(data, self.elemtype, (len(data)/self.elemsize,))
+        if self.swap_bytes:
+            ar.byteswap(True)
+        for elem in ar:
+            if elem!=0:
+                return False
+        return True
+    
+    def putRow(self,i,row):
+        if i<0 or i>=self.length:
+            raise IndexError('PMat index out of range')
+        if len(row)!=self.width:
+            raise TypeError('length of row ('+str(len(row))+ ') differs from matrix width ('+str(self.width)+')')
+        if i<0 or i>=self.length:
+            raise IndexError
+        if self.swap_bytes: # must make a copy and swap bytes
+            ar = numpy.numarray.numarray(row,type=self.elemtype)
+            ar.byteswap(True)
+        else: # asarray makes a copy if not already a numarray of the right type
+            ar = numpy.numarray.asarray(row,type=self.elemtype)
+        self.f.seek(64+i*self.rowsize)
+        self.f.write(ar.tostring())
+
+    def appendRow(self,row):
+        if len(row)!=self.width:
+            raise TypeError('length of row ('+str(len(row))+ ') differs from matrix width ('+str(self.width)+')')
+        if self.swap_bytes: # must make a copy and swap bytes
+            ar = numpy.numarray.numarray(row,type=self.elemtype)
+            ar.byteswap(True)
+        else: # asarray makes a copy if not already a numarray of the right type
+            ar = numpy.numarray.asarray(row,type=self.elemtype)
+
+        self.f.seek(64+self.length*self.rowsize)
+        self.f.write(ar.tostring())
+        self.length += 1
+        self.write_header() # update length in header
+
+    def flush(self):
+        self.f.flush()
+
+    def close(self):
+        if hasattr(self, 'f'):
+            self.f.close()
+
+    def append(self,row):
+        self.appendRow(row)
+
+    def __setitem__(self, i, row):
+        l = self.length
+        if i<0: i+=l
+        self.putRow(i,row)
+
+    def __len__(self):
+        return self.length
+
+            
+if __name__ == '__main__':
+    pmat = PMat( 'tmp.pmat', 'w', fieldnames=['F1', 'F2'] )
+    pmat.append( [1, 2] )
+    pmat.append( [3, 4] )
+    pmat.close()
+
+    pmat = PMat( 'tmp.pmat', 'r' )
+    ar=load_pmat_as_array('tmp.pmat')
+    ds=load_pmat_as_array_dataset('tmp.pmat')
+    
+    print "PMat",pmat
+    print "PMat",pmat[:]
+    print "array",ar
+    print "ArrayDataSet",ds
+    for i in ds:
+        print i
+    save_array_dataset_as_pmat("tmp2.pmat",ds)
+    ds2=load_pmat_as_array_dataset('tmp2.pmat')
+    for i in ds2:
+        print i
+    # print "+++ tmp.pmat contains: "
+    # os.system( 'plearn vmat cat tmp.pmat' )
+    import shutil
+    for fname in ["tmp.pmat", "tmp2.pmat"]:
+        os.remove( fname )
+        if os.path.exists( fname+'.metadata' ):
+            shutil.rmtree( fname+'.metadata' )
--- a/test_dataset.py	Wed May 07 15:07:56 2008 -0400
+++ b/test_dataset.py	Wed May 07 15:08:18 2008 -0400
@@ -320,6 +320,11 @@
       #    - 'fieldtypes': a list of types (one per field)
 
     #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#????
+#        i=0
+#        for example in hstack([ds('x'),ds('y'),ds('z')]):
+#            example==ds[i]
+#            i+=1 
+#        del i,example
     #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#????
 
 
@@ -396,6 +401,10 @@
     print "test_ArrayFieldsDataSet"
     raise NotImplementedError()
 
+import pmat
+
+load_pmat_as_array_dataset("tmp.pmat")
+
 test1()
 test_LookupList()
 test_ArrayDataSet()