Mercurial > pylearn
view pmat.py @ 240:97f35d586727
more test. Are we wanting to limit the size of minibatch to the size of the dataset?
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Thu, 29 May 2008 10:42:29 -0400 |
parents | 9330d941fa1f |
children | c2f17f231960 |
line wrap: on
line source
## Automatically adapted for numpy.numarray Jun 13, 2007 by python_numarray_to_numpy (-xsm)

# PMat.py
# Copyright (C) 2005 Pascal Vincent
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#   1. Redistributions of source code must retain the above copyright
#      notice, this list of conditions and the following disclaimer.
#
#   2. Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#
#   3. The name of the authors may not be used to endorse or promote
#      products derived from this software without specific prior written
#      permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR
#  IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
#  OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
#  NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
#  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
#  TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
#  PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
#  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
#  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#  This file is part of the PLearn library. For more information on the PLearn
#  library, go to the PLearn Web site at www.plearn.org

# Author: Pascal Vincent

# Reading and writing of ".pmat" matrix files.
#
# As implemented below, a .pmat file is a 64-byte ASCII header of the form
#     MATRIX <length> <width> <DOUBLE|FLOAT> <LITTLE_ENDIAN|BIG_ENDIAN>
# padded with spaces and terminated by a newline, followed by the matrix
# elements stored row after row.  Optional field names live in
# "<fname>.metadata/fieldnames", one name per line followed by "\t0".
#
# NOTE: the numpy.numarray compatibility layer used by earlier revisions is
# not shipped by current NumPy releases, so plain numpy calls are used here.

import os
import os.path
import sys

import numpy

# Size in bytes of the fixed-length header at the start of a .pmat file.
_HEADER_SIZE = 64

# elemtype code -> (type name written in the header, element size in bytes)
_ELEMTYPE_INFO = {'d': ('DOUBLE', 8), 'f': ('FLOAT', 4)}

# type name found in the header -> (elemtype code, element size in bytes)
_DATA_TYPE_INFO = {'DOUBLE': ('d', 8), 'FLOAT': ('f', 4)}

# endianness tag found in the header -> matching sys.byteorder value
_BYTEORDERS = {'LITTLE_ENDIAN': 'little', 'BIG_ENDIAN': 'big'}


####### Private helpers #####################################################

def _to_bytes(text):
    """Return `text` as a byte string (already-encoded input is passed through)."""
    if isinstance(text, bytes):
        return text
    return text.encode('utf-8')


def _to_text(raw):
    """Return `raw` as a native `str` (bytes are decoded as ASCII)."""
    if isinstance(raw, str):
        return raw
    return raw.decode('ascii')


def _format_header(length, width, type_name):
    """Build the 64-byte header for a matrix stored in native byte order.

    Raises TypeError when sys.byteorder is neither 'little' nor 'big', and
    ValueError when the header text would not fit in 64 bytes.
    """
    if sys.byteorder == 'little':
        endianness = 'LITTLE_ENDIAN'
    elif sys.byteorder == 'big':
        endianness = 'BIG_ENDIAN'
    else:
        raise TypeError('Unsupported sys.byteorder: ' + repr(sys.byteorder))
    header = ' '.join(['MATRIX', str(length), str(width),
                       type_name, endianness]) + ' '
    if len(header) > _HEADER_SIZE - 1:
        # A longer header would overwrite the beginning of the data section.
        raise ValueError('PMat header does not fit in %d bytes: %r'
                         % (_HEADER_SIZE, header))
    return _to_bytes(header.ljust(_HEADER_SIZE - 1) + '\n')


def _parse_header(raw):
    """Parse a raw .pmat header.

    Returns (mat_type, length, width, elemtype, elemsize, swap_bytes) where
    swap_bytes is True when the file's byte order differs from the host's.
    Raises ValueError on a malformed header.  The caller decides whether
    mat_type must be 'MATRIX'.
    """
    mat_type, l, w, data_type, endianness = _to_text(raw).split()
    if data_type not in _DATA_TYPE_INFO:
        raise ValueError('Invalid data type in file header: ' + data_type)
    if endianness not in _BYTEORDERS:
        raise ValueError('Invalid endianness in file header: ' + endianness)
    elemtype, elemsize = _DATA_TYPE_INFO[data_type]
    swap_bytes = (_BYTEORDERS[endianness] != sys.byteorder)
    return mat_type, int(l), int(w), elemtype, elemsize, swap_bytes


def _array_from_bytes(data, elemtype, shape, swap_bytes, offset=0):
    """Decode raw bytes into a new writable array of the given shape.

    `offset` is the number of leading bytes of `data` to skip.  When
    `swap_bytes` is true the elements are byte-swapped in place so that the
    returned array holds native-order values.
    """
    count = 1
    for dim in shape:
        count *= dim
    if count == 0:
        return numpy.zeros(shape, dtype=elemtype)
    # frombuffer returns a read-only view on `data`; copy to own the memory.
    ar = numpy.frombuffer(data, dtype=elemtype, count=count, offset=offset)
    ar = ar.reshape(shape).copy()
    if swap_bytes:
        ar.byteswap(True)
    return ar


def _read_fieldnames(fname, width):
    """Return the field names stored next to `fname`.

    The first whitespace-separated token of every non-blank line of
    "<fname>.metadata/fieldnames" is a field name.  When that file does not
    exist, the names "field_0" .. "field_<width-1>" are returned.
    """
    fieldnamefile = os.path.join(fname + '.metadata', 'fieldnames')
    if not os.path.isfile(fieldnamefile):
        return ['field_' + str(i) for i in range(width)]
    fieldnames = []
    with open(fieldnamefile) as f:
        for row in f:
            row = row.split()
            if row:
                fieldnames.append(row[0])
    return fieldnames


def _write_fieldnames(fname, fieldnames):
    """Write `fieldnames` to "<fname>.metadata/fieldnames", creating the directory."""
    metadatadir = fname + '.metadata'
    if not os.path.isdir(metadatadir):
        os.mkdir(metadatadir)
    with open(os.path.join(metadatadir, 'fieldnames'), 'wb') as f:
        for name in fieldnames:
            f.write(_to_bytes(name + '\t0\n'))


####### Module-level functions ##############################################

def array_columns(a, cols):
    """Return the selected columns of the 2-D array `a` as a 2-D array.

    `cols` may be a single int, a slice (resolved against the number of
    columns of `a`), or any iterable of column indices.
    """
    if isinstance(cols, int):
        indices = [cols]
    elif isinstance(cols, slice):
        indices = list(range(*cols.indices(numpy.shape(a)[1])))
    else:
        indices = list(cols)
    if not indices:
        # An empty Python list would be converted to a float array, which
        # numpy.take refuses to use as indices.
        indices = numpy.zeros(0, dtype=int)
    return numpy.take(a, indices, axis=1)


def load_pmat_as_array(fname):
    """Read the whole .pmat file `fname` and return it as a (length, width) array.

    Raises ValueError when the header is malformed.
    """
    with open(fname, 'rb') as f:
        s = f.read()
    _, l, w, elemtype, _, swap_bytes = _parse_header(s[:_HEADER_SIZE])
    return _array_from_bytes(s, elemtype, (l, w), swap_bytes,
                             offset=_HEADER_SIZE)


def load_pmat_as_array_dataset(fname):
    """Load `fname` into a dataset.ArrayDataSet, with field names from its metadata."""
    import dataset, lookup_list

    # load the pmat as array
    a = load_pmat_as_array(fname)

    # load the fieldnames (defaults to field_0, field_1, ... when absent)
    fieldnames = _read_fieldnames(fname, a.shape[1])

    return dataset.ArrayDataSet(
        a, lookup_list.LookupList(fieldnames, list(range(a.shape[1]))))


def save_array_dataset_as_pmat(fname, ds):
    """Save `ds.data` and `ds.fieldNames()` as the .pmat file `fname`."""
    ar = ds.data
    save_array_as_pmat(fname, ar, ds.fieldNames())


def save_array_as_pmat(fname, ar, fieldnames=None):
    """Write the 2-D array `ar` to the .pmat file `fname` in native byte order.

    When `fieldnames` is non-empty it must hold one name per column and is
    written to "<fname>.metadata/fieldnames".

    Raises TypeError when the dtype of `ar` is neither double ('d') nor
    float ('f').
    """
    length, width = ar.shape

    # Validate everything before touching the output file so that a failure
    # does not leave a truncated file behind.
    try:
        type_name, _ = _ELEMTYPE_INFO[ar.dtype.char]
    except KeyError:
        raise TypeError('Unsupported typecode: %s' % ar.dtype.char)
    header = _format_header(length, width, type_name)

    if fieldnames:
        assert len(fieldnames) == width
        _write_fieldnames(fname, fieldnames)

    # The header announces the host byte order, so make sure the bytes that
    # are written really are in that order.
    native = numpy.asarray(ar, dtype=numpy.dtype(ar.dtype.char))

    with open(fname, 'wb') as s:
        s.write(header)
        s.write(native.tobytes())


####### Iterators ###########################################################

class VMatIt(object):
    """Iterator over the rows of a VMat, from row 0 to row length-1."""

    def __init__(self, vmat):
        self.vmat = vmat
        self.cur_row = 0

    def __iter__(self):
        return self

    def next(self):
        if self.cur_row == self.vmat.length:
            raise StopIteration
        row = self.vmat.getRow(self.cur_row)
        self.cur_row += 1
        return row

    __next__ = next  # Python 3 iterator protocol


class ColumnIt(object):
    """Iterator over the values of column `col` of a VMat, one row at a time."""

    def __init__(self, vmat, col):
        self.vmat = vmat
        self.col = col
        self.cur_row = 0

    def __iter__(self):
        return self

    def next(self):
        if self.cur_row == self.vmat.length:
            raise StopIteration
        val = self.vmat[self.cur_row, self.col]
        self.cur_row += 1
        return val

    __next__ = next  # Python 3 iterator protocol


####### VMat classes ########################################################

class VMat(object):
    """Indexing and iteration on top of a row-oriented matrix.

    Subclasses are expected to provide `length`, `width`, `fieldnames`,
    `getRow(i)` and `getRows(i, l)`.
    """

    def __iter__(self):
        return VMatIt(self)

    def __getitem__(self, key):
        """Return rows and/or columns selected by `key`.

        * slice: the selected rows as a 2-D array (a step other than 1 is
          rejected with IndexError); bounds follow the usual Python slice
          rules, including negative values.
        * (rows, cols) tuple: the row selection, restricted to `cols`.
        * str: the column with that field name, as a (length, 1) array.
        * anything else: treated as a row index; negative values count from
          the end.
        """
        if isinstance(key, slice):
            if key.step not in (None, 1):
                raise IndexError('Extended slice with step not currently supported')
            start, stop, _ = slice(key.start, key.stop).indices(self.length)
            return self.getRows(start, max(stop - start, 0))

        elif isinstance(key, tuple):
            # Basically returns a SubVMatrix
            assert len(key) == 2
            rows = self.__getitem__(key[0])
            cols = key[1]

            if len(rows.shape) == 1:
                # A single row was selected: plain 1-D indexing applies.
                return rows[cols]

            if isinstance(cols, slice):
                # Resolve the slice against the actual number of columns so
                # that open-ended and negative bounds behave as usual.
                cols = list(range(*cols.indices(rows.shape[1])))

            return array_columns(rows, cols)

        elif isinstance(key, str):
            # The key is considered to be a fieldname and a column is
            # returned.
            try:
                col = self.fieldnames.index(key)
            except ValueError:
                sys.stderr.write("Key is '%s' while fieldnames are:\n" % key)
                sys.stderr.write('%s\n' % (self.fieldnames,))
                raise
            return array_columns(self.getRows(0, self.length), col)

        else:
            if key < 0:
                key += self.length
            return self.getRow(key)

    def getFieldIndex(self, fieldname):
        """Return the column index of `fieldname`; raise ValueError if unknown."""
        try:
            return self.fieldnames.index(fieldname)
        except ValueError:
            raise ValueError("VMat has no field named %s. Field names: %s"
                             % (fieldname, ','.join(self.fieldnames)))


class PMat(VMat):
    """A matrix stored in a .pmat file and accessed row by row.

    Open modes:
      'r'  read an existing file,
      'w'  create (or overwrite) a file with zero rows,
      'a'  open an existing file for reading and writing.

    In 'r' and 'a' modes the dimensions, element type and byte order come
    from the file header and the field names from the metadata directory;
    the `fieldnames` and `elemtype` arguments are only used in 'w' mode.
    """

    def __init__(self, fname, openmode='r', fieldnames=None, elemtype='d',
                 inputsize=-1, targetsize=-1, weightsize=-1, array=None):
        """Open `fname`; when `array` is given, append its content.

        A 2-D `array` is appended row by row; each element of a 1-D `array`
        is appended as a one-column row.

        Raises ValueError for an unsupported `openmode` or an `array` that
        is neither 1-D nor 2-D, and TypeError for an unsupported `elemtype`
        in 'w' mode.
        """
        self.fname = fname
        self.inputsize = inputsize
        self.targetsize = targetsize
        self.weightsize = weightsize

        if openmode == 'r':
            self.f = open(fname, 'rb')
            self.read_and_parse_header()
            self.load_fieldnames()

        elif openmode == 'w':
            # Reject a bad elemtype before creating anything on disk.
            if elemtype not in _ELEMTYPE_INFO:
                raise TypeError('Unsupported elemtype: ' + repr(elemtype))
            self.f = open(fname, 'w+b')
            self.fieldnames = [] if fieldnames is None else list(fieldnames)
            self.save_fieldnames()
            self.length = 0
            self.width = len(self.fieldnames)
            self.elemtype = elemtype
            self.swap_bytes = False
            self.write_header()

        elif openmode == 'a':
            self.f = open(fname, 'r+b')
            self.read_and_parse_header()
            self.load_fieldnames()

        else:
            raise ValueError("Currently only supported openmodes are 'r', 'w' and 'a': "
                             + repr(openmode) + " is not supported")

        if array is not None:
            ndim = len(array.shape)
            if ndim == 1:
                rows = ([value] for value in array)
            elif ndim == 2:
                rows = array
            else:
                raise ValueError('array must be 1-D or 2-D, got %d dimensions' % ndim)
            for row in rows:
                self.appendRow(row)

    def __del__(self):
        self.close()

    def write_header(self):
        """Write the header for the current length, width and elemtype.

        Also (re)computes `elemsize` and `rowsize`.  Raises TypeError when
        `elemtype` is not 'd' or 'f'.
        """
        try:
            type_name, self.elemsize = _ELEMTYPE_INFO[self.elemtype]
        except KeyError:
            raise TypeError('Unsupported elemtype: ' + repr(self.elemtype))
        self.rowsize = self.elemsize * self.width

        header = _format_header(self.length, self.width, type_name)
        self.f.seek(0)
        self.f.write(header)

    def read_and_parse_header(self):
        """Read the header at the current file position and set the matrix attributes.

        Sets length, width, elemtype, elemsize, rowsize and swap_bytes.
        Raises ValueError when the header is malformed.
        """
        header = self.f.read(_HEADER_SIZE)
        mat_type, length, width, elemtype, elemsize, swap_bytes = _parse_header(header)
        if mat_type != 'MATRIX':
            raise ValueError('Invalid file header (should start with MATRIX)')
        self.length = length
        self.width = width
        self.swap_bytes = swap_bytes
        self.elemtype = elemtype
        self.elemsize = elemsize
        self.rowsize = self.elemsize * self.width

    def load_fieldnames(self):
        """Set `fieldnames` from the metadata file, or to field_0, field_1, ..."""
        self.fieldnames = _read_fieldnames(self.fname, self.width)

    def save_fieldnames(self):
        """Write `fieldnames` to the metadata directory of this file."""
        _write_fieldnames(self.fname, self.fieldnames)

    def getRow(self, i):
        """Return row `i` as a 1-D array; raise IndexError when out of range."""
        if i < 0 or i >= self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(_HEADER_SIZE + i * self.rowsize)
        data = self.f.read(self.rowsize)
        return _array_from_bytes(data, self.elemtype, (self.width,),
                                 self.swap_bytes)

    def getRows(self, i, l):
        """Return the `l` rows starting at row `i` as an (l, width) array."""
        if i < 0 or l < 0 or i + l > self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(_HEADER_SIZE + i * self.rowsize)
        data = self.f.read(l * self.rowsize)
        return _array_from_bytes(data, self.elemtype, (l, self.width),
                                 self.swap_bytes)

    def checkzerorow(self, i):
        """Return True when every element read for row `i` equals zero.

        NOTE(review): `i == self.length` is accepted here (unlike getRow);
        nothing is read in that case and the result is True.  Confirm whether
        that is intended before tightening the bound.
        """
        if i < 0 or i > self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(_HEADER_SIZE + i * self.rowsize)
        data = self.f.read(self.rowsize)
        # Only decode the whole elements that were actually read.
        count = len(data) // self.elemsize
        ar = _array_from_bytes(data, self.elemtype, (count,), self.swap_bytes)
        return not ar.any()

    def _row_bytes(self, row):
        """Return `row` encoded with this file's element type and byte order."""
        if self.swap_bytes:
            # must make a copy and swap bytes
            ar = numpy.array(row, dtype=self.elemtype)
            ar.byteswap(True)
        else:
            # asarray makes a copy if not already an array of the right type
            ar = numpy.asarray(row, dtype=self.elemtype)
        return ar.tobytes()

    def putRow(self, i, row):
        """Overwrite row `i` with `row`.

        Raises IndexError when `i` is out of range and TypeError when `row`
        does not have `width` elements.
        """
        if i < 0 or i >= self.length:
            raise IndexError('PMat index out of range')
        if len(row) != self.width:
            raise TypeError('length of row (' + str(len(row)) +
                            ') differs from matrix width (' + str(self.width) + ')')
        data = self._row_bytes(row)
        self.f.seek(_HEADER_SIZE + i * self.rowsize)
        self.f.write(data)

    def appendRow(self, row):
        """Append `row` at the end of the matrix and update the header.

        Raises TypeError when `row` does not have `width` elements.
        """
        if len(row) != self.width:
            raise TypeError('length of row (' + str(len(row)) +
                            ') differs from matrix width (' + str(self.width) + ')')
        data = self._row_bytes(row)
        self.f.seek(_HEADER_SIZE + self.length * self.rowsize)
        self.f.write(data)
        self.length += 1
        self.write_header()  # update length in header

    def flush(self):
        self.f.flush()

    def close(self):
        # `f` is missing when opening the file failed in __init__.
        if hasattr(self, 'f'):
            self.f.close()

    def append(self, row):
        self.appendRow(row)

    def __setitem__(self, i, row):
        l = self.length
        if i < 0:
            i += l
        self.putRow(i, row)

    def __len__(self):
        return self.length


if __name__ == '__main__':
    import shutil

    pmat = PMat('tmp.pmat', 'w', fieldnames=['F1', 'F2'])
    pmat.append([1, 2])
    pmat.append([3, 4])
    pmat.close()

    pmat = PMat('tmp.pmat', 'r')
    ar = load_pmat_as_array('tmp.pmat')
    ds = load_pmat_as_array_dataset('tmp.pmat')

    print("PMat %s" % (pmat,))
    print("PMat %s" % (pmat[:],))
    print("array %s" % (ar,))
    print("ArrayDataSet %s" % (ds,))
    for i in ds:
        print(i)
    save_array_dataset_as_pmat("tmp2.pmat", ds)
    ds2 = load_pmat_as_array_dataset('tmp2.pmat')
    for i in ds2:
        print(i)
    # print "+++ tmp.pmat contains: "
    # os.system( 'plearn vmat cat tmp.pmat' )

    # Release the file handle before deleting the file.
    pmat.close()
    for fname in ["tmp.pmat", "tmp2.pmat"]:
        os.remove(fname)
        if os.path.exists(fname + '.metadata'):
            shutil.rmtree(fname + '.metadata')