comparison pylearn/io/pmat.py @ 537:b054271b2504

new file structure layout, factories, etc.
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 12 Nov 2008 21:57:54 -0500
parents pmat.py@c2f17f231960
children 3f44379177b2
comparison
equal deleted inserted replaced
518:4aa7f74ea93f 537:b054271b2504
1 ## Automatically adapted for numpy.numarray Jun 13, 2007 by python_numarray_to_numpy (-xsm)
2
3 # PMat.py
4 # Copyright (C) 2005 Pascal Vincent
5 #
6 # Redistribution and use in source and binary forms, with or without
7 # modification, are permitted provided that the following conditions are met:
8 #
9 # 1. Redistributions of source code must retain the above copyright
10 # notice, this list of conditions and the following disclaimer.
11 #
12 # 2. Redistributions in binary form must reproduce the above copyright
13 # notice, this list of conditions and the following disclaimer in the
14 # documentation and/or other materials provided with the distribution.
15 #
16 # 3. The name of the authors may not be used to endorse or promote
17 # products derived from this software without specific prior written
18 # permission.
19 #
20 # THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR
21 # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
22 # OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
23 # NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
25 # TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26 # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 #
31 # This file is part of the PLearn library. For more information on the PLearn
32 # library, go to the PLearn Web site at www.plearn.org
33
34
35 # Author: Pascal Vincent
36
37 #import numarray, sys, os, os.path
38 import numpy.numarray, sys, os, os.path
39 import fpconst
40
def array_columns(a, cols):
    """Return the selected columns of the 2-D array ``a`` as a 2-D array.

    :param a: two-dimensional array (rows x columns).
    :param cols: an int (a single column, still returned as a 2-D array with
        one column), a slice (resolved against the width of ``a`` with normal
        Python slicing semantics for omitted, negative or oversized bounds),
        or any iterable of column indices.
    :returns: ``numpy.numarray.take(a, indices, axis=1)``.
    """
    if isinstance(cols, int):
        indices = [cols]
    elif isinstance(cols, slice):
        # Resolve the slice against the real number of columns.  Using
        # ``cols.stop`` as the sequence length crashed on open-ended slices
        # (stop=None), mis-handled negative stops and produced out-of-range
        # indices whenever stop exceeded the width.
        indices = list(range(*cols.indices(a.shape[1])))
    else:
        indices = list(cols)

    return numpy.numarray.take(a, indices, axis=1)
52
def load_pmat_as_array(fname):
    """Read a whole PLearn .pmat file into memory as a 2-D array.

    The file begins with a 64-byte ASCII header
    ``MATRIX <length> <width> <DOUBLE|FLOAT> <LITTLE_ENDIAN|BIG_ENDIAN>``
    followed by the raw element data, row after row.

    :param fname: path of the .pmat file.
    :returns: array of shape ``(length, width)``; its bytes are swapped in
        place when the file's byte order differs from this machine's.
    :raises ValueError: if the header is not a MATRIX header or declares an
        unknown data type or endianness.
    """
    # Close the file deterministically instead of relying on garbage
    # collection to release an anonymous ``file(...)`` handle.
    f = open(fname, 'rb')
    try:
        s = f.read()
    finally:
        f.close()
    formatstr = s[0:64]
    datastr = s[64:]
    structuretype, l, w, data_type, endianness = formatstr.split()
    # Same sanity check as PMat.read_and_parse_header.
    if structuretype != 'MATRIX':
        raise ValueError('Invalid file header (should start with MATRIX)')

    if data_type == 'DOUBLE':
        elemtype = 'd'
    elif data_type == 'FLOAT':
        elemtype = 'f'
    else:
        raise ValueError('Invalid data type in file header: '+data_type)

    if endianness == 'LITTLE_ENDIAN':
        byteorder = 'little'
    elif endianness == 'BIG_ENDIAN':
        byteorder = 'big'
    else:
        raise ValueError('Invalid endianness in file header: '+endianness)

    l = int(l)
    w = int(w)
    X = numpy.numarray.fromstring(datastr, elemtype, shape=(l, w))
    if byteorder != sys.byteorder:
        X.byteswap(True)  # in-place swap to this machine's byte order
    return X
79
def load_pmat_as_array_dataset(fname):
    """Load a .pmat file as a ``dataset.ArrayDataSet``.

    Field names are read from ``<fname>.metadata/fieldnames`` (the first
    whitespace-separated token of each non-empty line).  When that file does
    not exist, the names ``field_0``, ``field_1``, ... are generated.

    :param fname: path of the .pmat file.
    :returns: ``dataset.ArrayDataSet`` built from the array and a
        ``lookup_list.LookupList`` pairing each field name with its column index.
    """
    import dataset, lookup_list

    # load the pmat as array
    a = load_pmat_as_array(fname)

    # load the fieldnames
    fieldnamefile = os.path.join(fname + '.metadata', 'fieldnames')
    if os.path.isfile(fieldnamefile):
        fieldnames = []
        f = open(fieldnamefile)
        try:
            for row in f:
                row = row.split()
                if len(row) > 0:
                    fieldnames.append(row[0])
        finally:
            f.close()
    else:
        # Bind the local: this is a plain function, there is no ``self``
        # (assigning ``self.fieldnames`` raised NameError here).
        fieldnames = ["field_" + str(i) for i in range(a.shape[1])]

    return dataset.ArrayDataSet(
        a, lookup_list.LookupList(fieldnames, list(range(a.shape[1]))))
100
def load_amat_as_array_dataset(fname):
    """Load a PLearn .amat (ASCII matrix) file as a ``dataset.ArrayDataSet``.

    Field names come from the file's ``#:`` line as parsed by ``readAMat``;
    when there is none, the names ``field_0``, ``field_1``, ... are generated.

    :param fname: path of the .amat file.
    :returns: ``dataset.ArrayDataSet`` built from the array and a
        ``lookup_list.LookupList`` pairing each field name with its column index.
    """
    import dataset, lookup_list

    # load the amat as array
    (a, fieldnames) = readAMat(fname)

    # No names in the file: generate them.  This must rebind the local that
    # is used below (assigning ``self.fieldnames`` raised NameError here).
    if len(fieldnames) == 0:
        fieldnames = ["field_" + str(i) for i in range(a.shape[1])]

    return dataset.ArrayDataSet(
        a, lookup_list.LookupList(fieldnames, list(range(a.shape[1]))))
112
def save_array_dataset_as_pmat(fname, ds):
    """Save the matrix held by ``ds`` to the .pmat file ``fname``.

    ``ds.data`` is the array written and ``ds.fieldNames()`` the column names;
    both are forwarded unchanged to ``save_array_as_pmat``.
    """
    save_array_as_pmat(fname, ds.data, ds.fieldNames())
116
def save_array_as_pmat(fname, ar, fieldnames=None):
    """Write the 2-D array ``ar`` to ``fname`` in PLearn .pmat format.

    The file gets a 64-byte ASCII header
    ``MATRIX <length> <width> <DOUBLE|FLOAT> <LITTLE_ENDIAN|BIG_ENDIAN>``
    (declaring this machine's byte order) followed by the raw bytes of ``ar``.

    :param fname: destination path; an existing file is overwritten.
    :param ar: 2-D array whose ``dtype.char`` is ``'d'`` or ``'f'``.
    :param fieldnames: optional sequence of column names, one per column of
        ``ar``; when non-empty they are written to
        ``<fname>.metadata/fieldnames``.
    :raises AssertionError: if ``fieldnames`` does not match the width of ``ar``.
    :raises TypeError: for an unsupported element type or byte order.
    """
    length, width = ar.shape

    # Validate everything before touching the filesystem so a bad argument no
    # longer leaves a truncated .pmat file (and an open handle) behind.
    if fieldnames:
        assert len(fieldnames) == width

    header = 'MATRIX ' + str(length) + ' ' + str(width) + ' '
    if ar.dtype.char == 'd':
        header += 'DOUBLE '
    elif ar.dtype.char == 'f':
        header += 'FLOAT '
    else:
        raise TypeError('Unsupported typecode: %s' % ar.dtype.char)

    if sys.byteorder == 'little':
        header += 'LITTLE_ENDIAN '
    elif sys.byteorder == 'big':
        header += 'BIG_ENDIAN '
    else:
        raise TypeError('Unsupported sys.byteorder: '+repr(sys.byteorder))

    # Pad to exactly 64 bytes: 63 characters plus the newline.
    header += ' ' * (63 - len(header)) + '\n'

    # The header declares the machine's byte order, so the dumped bytes must
    # be in that order too; a non-native array (e.g. '>f8' on a little-endian
    # host) would otherwise be silently mislabelled.
    if not ar.dtype.isnative:
        ar = ar.astype(ar.dtype.newbyteorder('='))

    if fieldnames:
        metadatadir = fname + '.metadata'
        if not os.path.isdir(metadatadir):
            os.mkdir(metadatadir)
        f = open(os.path.join(metadatadir, 'fieldnames'), 'wb')
        try:
            for name in fieldnames:
                f.write(name + '\t0\n')
        finally:
            f.close()

    s = open(fname, 'wb')
    try:
        s.write(header)
        s.write(ar.tostring())
    finally:
        s.close()
157
158
159 ####### Iterators ###########################################################
160
class VMatIt:
    """Iterator over the rows of a VMat, in order.

    Yields ``vmat.getRow(i)`` for ``i = 0, 1, ..., vmat.length - 1``.
    """

    def __init__(self, vmat):
        self.vmat = vmat
        self.cur_row = 0

    def __iter__(self):
        return self

    def next(self):
        """Return the next row; raise StopIteration once all rows were returned."""
        if self.cur_row >= self.vmat.length:
            raise StopIteration
        row = self.vmat.getRow(self.cur_row)
        self.cur_row += 1
        return row

    # Python 3 names the iterator-protocol method ``__next__``; the alias
    # keeps ``for row in vmat`` working there while Python 2 still uses next().
    __next__ = next
175
class ColumnIt:
    """Iterator over one column of a VMat, row by row.

    Yields ``vmat[i, col]`` for ``i = 0, 1, ..., vmat.length - 1``.
    """

    def __init__(self, vmat, col):
        self.vmat = vmat
        self.col = col
        self.cur_row = 0

    def __iter__(self):
        return self

    def next(self):
        """Return the next value; raise StopIteration after the last row."""
        if self.cur_row >= self.vmat.length:
            raise StopIteration
        val = self.vmat[self.cur_row, self.col]
        self.cur_row += 1
        return val

    # Python 3 names the iterator-protocol method ``__next__``.
    __next__ = next
191
192 ####### VMat classes ########################################################
193
class VMat:
    """Base class for row-addressable matrices ("virtual matrices").

    Subclasses provide the attributes ``length``, ``width`` and
    ``fieldnames`` and the methods ``getRow(i)`` (a single row) and
    ``getRows(i, l)`` (``l`` consecutive rows starting at ``i``, as a 2-D
    array).  This class adds iteration and ``[]`` indexing on top of them.
    """

    def __iter__(self):
        return VMatIt(self)

    def __getitem__(self, key):
        """Index the matrix.

        * ``vmat[i]``: row ``i``; a negative ``i`` counts from the end.
        * ``vmat[i:j]``: rows ``i`` to ``j-1`` via ``getRows``, with ordinary
          Python slicing semantics for omitted, negative or out-of-range
          bounds.  An explicit step is not supported.
        * ``vmat[rows, cols]``: sub-matrix.  ``rows`` is any of the row keys
          above; ``cols`` is an int, a slice or a sequence of column indices.
        * ``vmat['name']``: the whole column with that field name.

        :raises IndexError: for a slice with an explicit step.
        :raises ValueError: for an unknown field name.
        """
        if isinstance(key, slice):
            if key.step is not None:
                raise IndexError('Extended slice with step not currently supported')
            # slice.indices() resolves None / negative / oversized bounds
            # against the current length, exactly like list slicing does.
            start, stop, _ = key.indices(self.length)
            return self.getRows(start, max(0, stop - start))

        elif isinstance(key, tuple):
            # Basically returns a SubVMatrix
            assert len(key) == 2
            rows = self.__getitem__(key[0])
            if len(rows.shape) == 1:
                # A single row was selected: index into it directly.
                return rows[key[1]]

            cols = key[1]
            if isinstance(cols, slice):
                # Turn the column slice into an explicit index list resolved
                # against our width, so negative or omitted bounds and
                # negative steps behave like ordinary Python slicing.
                cols = list(range(*cols.indices(self.width)))
            return array_columns(rows, cols)

        elif isinstance(key, str):
            # The key is a field name and the whole column is returned.  The
            # name is resolved before reading any data, so an unknown field
            # fails fast with a message listing the available names.
            col = self.getFieldIndex(key)
            return array_columns(self.getRows(0, self.length), col)

        else:
            if key < 0:
                key += self.length
            return self.getRow(key)

    def getFieldIndex(self, fieldname):
        """Return the column index of ``fieldname``.

        :raises ValueError: if there is no such field; the message lists the
            available field names.
        """
        try:
            return self.fieldnames.index(fieldname)
        except ValueError:
            raise ValueError("VMat has no field named %s. Field names: %s"
                             % (fieldname, ','.join(self.fieldnames)))
258
class PMat(VMat):
    """A VMat stored in a PLearn .pmat file.

    On disk the file is a 64-byte ASCII header
    ``MATRIX <length> <width> <DOUBLE|FLOAT> <LITTLE_ENDIAN|BIG_ENDIAN>``
    followed by ``length`` rows of ``width`` raw elements.  Field names are
    kept one per line in ``<fname>.metadata/fieldnames``.

    :param fname: path of the .pmat file.
    :param openmode: ``'r'`` to read an existing file, ``'w'`` to create (or
        truncate) one, ``'a'`` to open an existing file for updating and
        appending rows.
    :param fieldnames: column names, used only with ``openmode='w'``; their
        number sets the width of the new matrix.
    :param elemtype: ``'d'`` (8-byte double) or ``'f'`` (4-byte float), used
        only with ``openmode='w'``.
    :param inputsize, targetsize, weightsize: stored as attributes; not used
        by this class itself.
    :param array: optional 1-D or 2-D array whose rows are appended right
        after opening (each element of a 1-D array becomes a one-column row).
    :raises ValueError: for an unknown ``openmode`` or an ``array`` that is
        neither 1-D nor 2-D.
    """

    def __init__(self, fname, openmode='r', fieldnames=None, elemtype='d',
                 inputsize=-1, targetsize=-1, weightsize=-1, array=None):
        self.fname = fname
        self.inputsize = inputsize
        self.targetsize = targetsize
        self.weightsize = weightsize
        if openmode == 'r':
            self.f = open(fname, 'rb')
            self.read_and_parse_header()
            self.load_fieldnames()

        elif openmode == 'w':
            # None instead of a shared ``[]`` default: the list is stored on
            # the instance, so a mutable default would be shared by every
            # PMat created without explicit field names.
            if fieldnames is None:
                fieldnames = []
            self.f = open(fname, 'w+b')
            self.fieldnames = fieldnames
            self.save_fieldnames()
            self.length = 0
            self.width = len(fieldnames)
            self.elemtype = elemtype
            self.swap_bytes = False  # new files are written in native order
            self.write_header()

        elif openmode == 'a':
            self.f = open(fname, 'r+b')
            self.read_and_parse_header()
            self.load_fieldnames()

        else:
            raise ValueError("Currently only supported openmodes are 'r', 'w' and 'a': "+repr(openmode)+" is not supported")

        if array is not None:
            ndim = len(array.shape)
            if ndim == 1:
                row_format = lambda r: [r]
            elif ndim == 2:
                row_format = lambda r: r
            else:
                # Previously row_format was simply left unbound here.
                raise ValueError('array must be 1-D or 2-D, got shape %s'
                                 % (array.shape,))
            for row in array:
                self.appendRow(row_format(row))

    def __del__(self):
        self.close()

    def write_header(self):
        """(Re)write the 64-byte header at the start of the file.

        Also (re)computes ``elemsize`` and ``rowsize`` from ``elemtype`` and
        ``width``.  The declared byte order is the one rows are stored in:
        this machine's, or the opposite one when ``swap_bytes`` is set.

        :raises TypeError: for an unsupported ``elemtype`` or byte order.
        """
        header = 'MATRIX ' + str(self.length) + ' ' + str(self.width) + ' '

        if self.elemtype == 'd':
            header += 'DOUBLE '
            self.elemsize = 8
        elif self.elemtype == 'f':
            header += 'FLOAT '
            self.elemsize = 4
        else:
            raise TypeError('Unsupported elemtype: '+repr(self.elemtype))
        self.rowsize = self.elemsize*self.width

        if sys.byteorder not in ('little', 'big'):
            raise TypeError('Unsupported sys.byteorder: '+repr(sys.byteorder))
        file_order = sys.byteorder
        if self.swap_bytes:
            # Rows of a foreign-endian file opened with 'a' are byte-swapped
            # into that foreign order, so the header must keep declaring it.
            file_order = {'little': 'big', 'big': 'little'}[file_order]
        if file_order == 'little':
            header += 'LITTLE_ENDIAN '
        else:
            header += 'BIG_ENDIAN '

        # Pad to exactly 64 bytes: 63 characters plus the newline.
        header += ' '*(63-len(header))+'\n'

        self.f.seek(0)
        self.f.write(header)

    def read_and_parse_header(self):
        """Read and parse the 64-byte header at the current file position.

        Called right after opening, i.e. at the start of the file.  Sets
        ``length``, ``width``, ``elemtype``, ``elemsize``, ``rowsize`` and
        ``swap_bytes`` (true when the file's byte order is not this machine's).

        :raises ValueError: if the header is malformed.
        """
        header = self.f.read(64)
        mat_type, l, w, data_type, endianness = header.split()
        if mat_type != 'MATRIX':
            raise ValueError('Invalid file header (should start with MATRIX)')
        self.length = int(l)
        self.width = int(w)
        if endianness == 'LITTLE_ENDIAN':
            byteorder = 'little'
        elif endianness == 'BIG_ENDIAN':
            byteorder = 'big'
        else:
            raise ValueError('Invalid endianness in file header: '+endianness)
        self.swap_bytes = (byteorder != sys.byteorder)

        if data_type == 'DOUBLE':
            self.elemtype = 'd'
            self.elemsize = 8
        elif data_type == 'FLOAT':
            self.elemtype = 'f'
            self.elemsize = 4
        else:
            raise ValueError('Invalid data type in file header: '+data_type)
        self.rowsize = self.elemsize*self.width

    def load_fieldnames(self):
        """Load ``fieldnames`` from the metadata directory.

        Uses the first token of each non-empty line of
        ``<fname>.metadata/fieldnames``, or ``field_0``, ``field_1``, ... when
        that file does not exist.
        """
        self.fieldnames = []
        fieldnamefile = os.path.join(self.fname+'.metadata', 'fieldnames')
        if os.path.isfile(fieldnamefile):
            f = open(fieldnamefile)
            try:
                for row in f:
                    row = row.split()
                    if len(row) > 0:
                        self.fieldnames.append(row[0])
            finally:
                f.close()
        else:
            self.fieldnames = ["field_"+str(i) for i in range(self.width)]

    def save_fieldnames(self):
        """Write ``fieldnames`` to ``<fname>.metadata/fieldnames``.

        One ``<name>\\t0`` line per field; the directory is created if needed.
        """
        metadatadir = self.fname+'.metadata'
        if not os.path.isdir(metadatadir):
            os.mkdir(metadatadir)
        fieldnamefile = os.path.join(metadatadir, 'fieldnames')
        f = open(fieldnamefile, 'wb')
        try:
            for name in self.fieldnames:
                f.write(name+'\t0\n')
        finally:
            f.close()

    def getRow(self, i):
        """Return row ``i`` (0 <= i < length) as a 1-D array of ``width`` elements.

        :raises IndexError: if ``i`` is out of range.
        """
        if i < 0 or i >= self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(self.rowsize)
        ar = numpy.numarray.fromstring(data, self.elemtype, (self.width,))
        if self.swap_bytes:
            ar.byteswap(True)
        return ar

    def getRows(self, i, l):
        """Return ``l`` consecutive rows starting at ``i`` as an ``(l, width)`` array.

        :raises IndexError: if the requested range is not within the matrix.
        """
        if i < 0 or l < 0 or i+l > self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(l*self.rowsize)
        ar = numpy.numarray.fromstring(data, self.elemtype, (l, self.width))
        if self.swap_bytes:
            ar.byteswap(True)
        return ar

    def checkzerorow(self, i):
        """Return True if every element stored for row ``i`` is zero.

        ``i`` may equal ``length`` (the slot just past the last row).  Only
        the bytes actually read are checked, so a short or empty read (e.g.
        past the end of the file) counts as an all-zero row.

        :raises IndexError: if ``i < 0`` or ``i > length``.
        """
        if i < 0 or i > self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(self.rowsize)
        # Integer division: the element count of a possibly short read.
        ar = numpy.numarray.fromstring(data, self.elemtype,
                                       (len(data)//self.elemsize,))
        if self.swap_bytes:
            ar.byteswap(True)
        for elem in ar:
            if elem != 0:
                return False
        return True

    def _row_bytes(self, row):
        """Check ``row``'s length and return its raw bytes in the file's byte order."""
        if len(row) != self.width:
            raise TypeError('length of row ('+str(len(row))+ ') differs from matrix width ('+str(self.width)+')')
        if self.swap_bytes:
            # array() always copies, so the in-place swap never touches the
            # caller's data.
            ar = numpy.numarray.array(row, type=self.elemtype)
            ar.byteswap(True)
        else:
            # asarray() copies only when row is not already of the right type.
            ar = numpy.numarray.asarray(row, type=self.elemtype)
        return ar.tostring()

    def putRow(self, i, row):
        """Overwrite existing row ``i`` (0 <= i < length) with ``row``.

        :raises IndexError: if ``i`` is out of range.
        :raises TypeError: if ``len(row) != width``.
        """
        if i < 0 or i >= self.length:
            raise IndexError('PMat index out of range')
        data = self._row_bytes(row)
        self.f.seek(64+i*self.rowsize)
        self.f.write(data)

    def appendRow(self, row):
        """Append ``row`` after the last row and update the length in the header.

        :raises TypeError: if ``len(row) != width``.
        """
        data = self._row_bytes(row)
        self.f.seek(64+self.length*self.rowsize)
        self.f.write(data)
        self.length += 1
        self.write_header()  # update length in header

    def flush(self):
        self.f.flush()

    def close(self):
        # hasattr: __init__ may have failed before the file was opened.
        if hasattr(self, 'f'):
            self.f.close()

    def append(self, row):
        self.appendRow(row)

    def __setitem__(self, i, row):
        l = self.length
        if i < 0:
            i += l
        self.putRow(i, row)

    def __len__(self):
        return self.length
455
456
457
458 #copied from PLEARNDIR:python_modules/plearn/vmat/readAMat.py
def safefloat(str):
    """Parse ``str`` as a float, treating 'nan' (any letter case) specially.

    The missing-value marker 'nan' is mapped to ``fpconst.NaN`` instead of
    being handed to ``float()``, whose handling of 'nan' is not guaranteed on
    every platform; any other string goes straight through ``float()``.
    """
    if str.lower() != 'nan':
        return float(str)
    return fpconst.NaN
468
469 #copied from PLEARNDIR:python_modules/plearn/vmat/readAMat.py
def readAMat(amatname):
    """Read a PLearn .amat (ASCII matrix) file.

    Lines starting with ``#:`` give the field names.  All other lines
    starting with ``#`` (including the ``#size:`` / ``#sizes:`` headers) are
    ignored.  Every remaining non-blank line is one row of whitespace-separated
    numbers, parsed with ``safefloat`` so 'nan' marks a missing value.

    :param amatname: path of the .amat file.
    :returns: ``(array, fieldnames)``: the rows as a numarray array and the
        list of field-name strings (empty if the file has no ``#:`` line).
    """
    ### NOTE: collecting the rows in a list and converting once at the end is
    ### much faster than creating the array first and updating it row by row.
    a = []
    fieldnames = []
    f = open(amatname)
    try:
        for line in f:
            if line.startswith("#:"):
                fieldnames = line[2:].strip().split()
            elif line.startswith('#'):
                # Comments and the '#size:'/'#sizes:' headers: the shape is
                # taken from the rows themselves, so nothing here is needed.
                continue
            else:
                row = [safefloat(x) for x in line.split()]
                if row:  # skip blank lines
                    a.append(row)
    finally:
        f.close()
    return numpy.numarray.array(a), fieldnames
498
499
if __name__ == '__main__':
    import shutil

    # Round-trip self-test: write a small .pmat, read it back through the
    # various loaders, save it again, and always remove the temporary files.
    tmpfiles = ["tmp.pmat", "tmp2.pmat"]
    try:
        pmat = PMat('tmp.pmat', 'w', fieldnames=['F1', 'F2'])
        pmat.append([1, 2])
        pmat.append([3, 4])
        pmat.close()

        pmat = PMat('tmp.pmat', 'r')
        ar = load_pmat_as_array('tmp.pmat')
        ds = load_pmat_as_array_dataset('tmp.pmat')

        sys.stdout.write("PMat %s\n" % (pmat,))
        sys.stdout.write("PMat %s\n" % (pmat[:],))
        sys.stdout.write("array %s\n" % (ar,))
        sys.stdout.write("ArrayDataSet %s\n" % (ds,))
        for i in ds:
            sys.stdout.write("%s\n" % (i,))
        # Release the handle before the file is deleted (required on Windows).
        pmat.close()

        save_array_dataset_as_pmat("tmp2.pmat", ds)
        ds2 = load_pmat_as_array_dataset('tmp2.pmat')
        for i in ds2:
            sys.stdout.write("%s\n" % (i,))
    finally:
        for fname in tmpfiles:
            if os.path.exists(fname):
                os.remove(fname)
            if os.path.exists(fname + '.metadata'):
                shutil.rmtree(fname + '.metadata')