Mercurial > pylearn
comparison pylearn/io/pmat.py @ 537:b054271b2504
new file structure layout, factories, etc.
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 12 Nov 2008 21:57:54 -0500 |
parents | pmat.py@c2f17f231960 |
children | 3f44379177b2 |
comparison
equal
deleted
inserted
replaced
518:4aa7f74ea93f | 537:b054271b2504 |
---|---|
1 ## Automatically adapted for numpy.numarray Jun 13, 2007 by python_numarray_to_numpy (-xsm) | |
2 | |
3 # PMat.py | |
4 # Copyright (C) 2005 Pascal Vincent | |
5 # | |
6 # Redistribution and use in source and binary forms, with or without | |
7 # modification, are permitted provided that the following conditions are met: | |
8 # | |
9 # 1. Redistributions of source code must retain the above copyright | |
10 # notice, this list of conditions and the following disclaimer. | |
11 # | |
12 # 2. Redistributions in binary form must reproduce the above copyright | |
13 # notice, this list of conditions and the following disclaimer in the | |
14 # documentation and/or other materials provided with the distribution. | |
15 # | |
16 # 3. The name of the authors may not be used to endorse or promote | |
17 # products derived from this software without specific prior written | |
18 # permission. | |
19 # | |
20 # THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR | |
21 # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES | |
22 # OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN | |
23 # NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED | |
25 # TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | |
26 # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF | |
27 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |
28 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
29 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
30 # | |
31 # This file is part of the PLearn library. For more information on the PLearn | |
32 # library, go to the PLearn Web site at www.plearn.org | |
33 | |
34 | |
35 # Author: Pascal Vincent | |
36 | |
37 #import numarray, sys, os, os.path | |
38 import numpy.numarray, sys, os, os.path | |
39 import fpconst | |
40 | |
def array_columns( a, cols ):
    """Return the columns of 2-D array `a` selected by `cols`.

    `cols` may be a single int, a slice, or any iterable of column
    indices.  The result always keeps a column axis (a single int is
    wrapped in a one-element list).
    """
    indices = None
    if isinstance( cols, int ):
        indices = [ cols ]
    elif isinstance( cols, slice ):
        # Resolve the slice against the actual number of columns.  The
        # original passed cols.stop here, which crashes when stop is
        # None and mishandles negative stops; using a.shape[1] is
        # equivalent for every existing caller (stop <= width) and also
        # handles open-ended slices.
        indices = range( *cols.indices( a.shape[1] ) )
    else:
        indices = list( cols )

    return numpy.numarray.take(a, indices, axis=1)
52 | |
def load_pmat_as_array(fname):
    """Load an entire PLearn .pmat file into a 2-D numarray.

    File layout: a 64-byte ASCII header "MATRIX <length> <width>
    <DOUBLE|FLOAT> <LITTLE_ENDIAN|BIG_ENDIAN>" followed by raw
    row-major matrix data.  Bytes are swapped in place when the file's
    endianness differs from the host's.

    Raises ValueError for an unrecognized data type or endianness.
    """
    # The original used the Python-2-only file() builtin and never
    # closed the handle; open + try/finally fixes both.
    f = open(fname,'rb')
    try:
        s = f.read()
    finally:
        f.close()
    formatstr = s[0:64]
    datastr = s[64:]
    structuretype, l, w, data_type, endianness = formatstr.split()

    if data_type=='DOUBLE':
        elemtype = 'd'
    elif data_type=='FLOAT':
        elemtype = 'f'
    else:
        raise ValueError('Invalid data type in file header: '+data_type)

    if endianness=='LITTLE_ENDIAN':
        byteorder = 'little'
    elif endianness=='BIG_ENDIAN':
        byteorder = 'big'
    else:
        raise ValueError('Invalid endianness in file header: '+endianness)

    l = int(l)
    w = int(w)
    X = numpy.numarray.fromstring(datastr,elemtype, shape=(l,w) )
    if byteorder!=sys.byteorder:
        # in-place byteswap: file endianness differs from host
        X.byteswap(True)
    return X
79 | |
def load_pmat_as_array_dataset(fname):
    """Load a .pmat file as a dataset.ArrayDataSet.

    Field names are read from <fname>.metadata/fieldnames when that
    file exists; otherwise synthetic names "field_0", "field_1", ...
    are generated, one per column.
    """
    import dataset,lookup_list

    # load the pmat as array
    a=load_pmat_as_array(fname)

    # load the fieldnames
    fieldnames = []
    fieldnamefile = os.path.join(fname+'.metadata','fieldnames')
    if os.path.isfile(fieldnamefile):
        f = open(fieldnamefile)
        try:
            for row in f:
                row = row.split()
                if len(row)>0:
                    # first token is the name; the rest (e.g. the
                    # trailing "0") is ignored
                    fieldnames.append(row[0])
        finally:
            f.close()
    else:
        # BUG FIX: this is a module-level function, so there is no
        # `self`; the original `self.fieldnames = ...` raised NameError
        # whenever the metadata file was missing.
        fieldnames = [ "field_"+str(i) for i in range(a.shape[1]) ]

    return dataset.ArrayDataSet(a,lookup_list.LookupList(fieldnames,[x for x in range(a.shape[1])]))
100 | |
def load_amat_as_array_dataset(fname):
    """Load a PLearn .amat (ASCII matrix) file as a dataset.ArrayDataSet.

    readAMat supplies both the data and any fieldnames declared in the
    file; synthetic "field_<i>" names are generated when none are
    declared.
    """
    import dataset,lookup_list

    # load the amat as array (fieldnames come from the "#:" header line)
    (a,fieldnames)=readAMat(fname)

    if len(fieldnames)==0:
        # BUG FIX: module-level function, there is no `self`; the
        # original `self.fieldnames = ...` raised NameError for files
        # without a fieldnames header.
        fieldnames = [ "field_"+str(i) for i in range(a.shape[1]) ]

    return dataset.ArrayDataSet(a,lookup_list.LookupList(fieldnames,[x for x in range(a.shape[1])]))
112 | |
def save_array_dataset_as_pmat(fname,ds):
    """Write the ArrayDataSet `ds` (data + field names) to `fname` in
    PLearn .pmat format."""
    save_array_as_pmat(fname, ds.data, ds.fieldNames())
116 | |
def save_array_as_pmat( fname, ar, fieldnames=[] ):
    """Save the 2-D array `ar` to `fname` in PLearn .pmat format.

    If `fieldnames` is non-empty it must contain one name per column;
    the names are written to <fname>.metadata/fieldnames.  Only arrays
    with dtype char 'd' (float64) or 'f' (float32) are supported;
    anything else raises TypeError.  (`fieldnames` defaults to a shared
    list but is never mutated, so the mutable default is harmless.)
    """
    length, width = ar.shape

    if fieldnames:
        assert len(fieldnames) == width
        metadatadir = fname+'.metadata'
        if not os.path.isdir(metadatadir):
            os.mkdir(metadatadir)
        fieldnamefile = os.path.join(metadatadir,'fieldnames')
        f = open(fieldnamefile,'wb')
        try:
            for name in fieldnames:
                # encode for the binary-mode handle (no-op on Python 2)
                f.write((name+'\t0\n').encode('ascii'))
        finally:
            f.close()

    # Build the complete 64-byte header before creating the output file
    # so an unsupported dtype no longer leaves a truncated .pmat behind.
    header = 'MATRIX ' + str(length) + ' ' + str(width) + ' '
    if ar.dtype.char=='d':
        header += 'DOUBLE '
    elif ar.dtype.char=='f':
        header += 'FLOAT '
    else:
        raise TypeError('Unsupported typecode: %s' % ar.dtype.char)

    if sys.byteorder=='little':
        header += 'LITTLE_ENDIAN '
    elif sys.byteorder=='big':
        header += 'BIG_ENDIAN '
    else:
        raise TypeError('Unsupported sys.byteorder: '+repr(sys.byteorder))

    # pad to exactly 64 bytes (63 chars + newline)
    header += ' '*(63-len(header))+'\n'

    s = open(fname,'wb')
    try:
        s.write( header.encode('ascii') )
        s.write( ar.tostring() )
    finally:
        s.close()
157 | |
158 | |
159 ####### Iterators ########################################################### | |
160 | |
161 class VMatIt: | |
162 def __init__(self, vmat): | |
163 self.vmat = vmat | |
164 self.cur_row = 0 | |
165 | |
166 def __iter__(self): | |
167 return self | |
168 | |
169 def next(self): | |
170 if self.cur_row==self.vmat.length: | |
171 raise StopIteration | |
172 row = self.vmat.getRow(self.cur_row) | |
173 self.cur_row += 1 | |
174 return row | |
175 | |
176 class ColumnIt: | |
177 def __init__(self, vmat, col): | |
178 self.vmat = vmat | |
179 self.col = col | |
180 self.cur_row = 0 | |
181 | |
182 def __iter__(self): | |
183 return self | |
184 | |
185 def next(self): | |
186 if self.cur_row==self.vmat.length: | |
187 raise StopIteration | |
188 val = self.vmat[self.cur_row, self.col] | |
189 self.cur_row += 1 | |
190 return val | |
191 | |
192 ####### VMat classes ######################################################## | |
193 | |
class VMat:
    """Base class layering Python iteration/indexing over row-accessible
    matrices.

    Subclasses are expected to provide `length`, `width`, `fieldnames`,
    `getRow(i)` and `getRows(i, l)` (see PMat); this class only uses
    those primitives.
    """
    def __iter__(self):
        return VMatIt(self)

    def __getitem__( self, key ):
        """Index with an int (one row), a slice (a range of rows), a
        (rows, cols) tuple (a sub-matrix), or a fieldname string (one
        column)."""
        if isinstance( key, slice ):
            # Row-range access; an extended slice step is not supported.
            start, stop, step = key.start, key.stop, key.step
            if step!=None:
                raise IndexError('Extended slice with step not currently supported')

            if start is None:
                start = 0

            # clamp stop to the matrix length
            l = self.length
            if stop is None or stop > l:
                stop = l

            return self.getRows(start,stop-start)

        elif isinstance( key, tuple ):
            # Basically returns a SubVMatrix
            assert len(key) == 2
            rows = self.__getitem__( key[0] )

            shape = rows.shape
            if len(shape) == 1:
                # a single row was selected: plain element/slice access
                return rows[ key[1] ]

            cols = key[1]
            if isinstance(cols, slice):
                # normalize the column slice bounds against self.width
                start, stop, step = cols.start, cols.stop, cols.step
                if start is None:
                    start = 0

                if stop is None:
                    stop = self.width
                elif stop < 0:
                    stop = self.width+stop

                cols = slice(start, stop, step)

            return array_columns(rows, cols)

        elif isinstance( key, str ):
            # The key is considered to be a fieldname and a column is
            # returned.
            try:
                return array_columns( self.getRows(0,self.length),
                                      self.fieldnames.index(key) )
            except ValueError:
                # report the known fieldnames before re-raising
                print >>sys.stderr, "Key is '%s' while fieldnames are:" % key
                print >>sys.stderr, self.fieldnames
                raise

        else:
            # assume an integer row index; negatives count from the end
            if key<0: key+=self.length
            return self.getRow(key)

    def getFieldIndex(self, fieldname):
        """Return the column index of `fieldname`; raise ValueError
        listing the known field names if it is absent."""
        try:
            return self.fieldnames.index(fieldname)
        except ValueError:
            raise ValueError( "VMat has no field named %s. Field names: %s"
                              %(fieldname, ','.join(self.fieldnames)) )
258 | |
class PMat( VMat ):
    """A VMat backed by a PLearn .pmat file on disk.

    The file is a 64-byte ASCII header ("MATRIX <length> <width>
    <DOUBLE|FLOAT> <LITTLE_ENDIAN|BIG_ENDIAN>") followed by raw
    row-major data.  Field names live in a sidecar directory
    <fname>.metadata/fieldnames.  Supported openmodes: 'r' (read),
    'w' (create/truncate), 'a' (append).  elemtype is 'd' (float64)
    or 'f' (float32).
    """

    def __init__(self, fname, openmode='r', fieldnames=[], elemtype='d',
                 inputsize=-1, targetsize=-1, weightsize=-1, array = None):
        # NOTE: `fieldnames` defaults to a shared list but is never
        # mutated here, so the mutable default is harmless.
        self.fname = fname
        self.inputsize = inputsize
        self.targetsize = targetsize
        self.weightsize = weightsize
        if openmode=='r':
            self.f = open(fname,'rb')
            self.read_and_parse_header()
            self.load_fieldnames()

        elif openmode=='w':
            self.f = open(fname,'w+b')
            self.fieldnames = fieldnames
            self.save_fieldnames()
            self.length = 0
            self.width = len(fieldnames)
            self.elemtype = elemtype
            self.swap_bytes = False
            self.write_header()

        elif openmode=='a':
            self.f = open(fname,'r+b')
            self.read_and_parse_header()
            self.load_fieldnames()

        else:
            raise ValueError("Currently only supported openmodes are 'r', 'w' and 'a': "+repr(openmode)+" is not supported")

        if array is not None:
            # seed the matrix from an array: a 1-D array is appended as
            # a single row, a 2-D array row by row
            shape = array.shape
            if len(shape) == 1:
                row_format = lambda r: [ r ]
            elif len(shape) == 2:
                row_format = lambda r: r

            for row in array:
                self.appendRow( row_format(row) )

    def __del__(self):
        # best-effort close; close() tolerates a partially-constructed
        # instance with no f attribute
        self.close()

    def write_header(self):
        """(Re)write the 64-byte header at the start of the file from
        the current length/width/elemtype."""
        header = 'MATRIX ' + str(self.length) + ' ' + str(self.width) + ' '

        if self.elemtype=='d':
            header += 'DOUBLE '
            self.elemsize = 8
        elif self.elemtype=='f':
            header += 'FLOAT '
            self.elemsize = 4
        else:
            # BUG FIX: the original referenced the bare name `elemtype`,
            # which is undefined here, so this branch raised NameError
            # instead of the intended TypeError.
            raise TypeError('Unsupported elemtype: '+repr(self.elemtype))
        self.rowsize = self.elemsize*self.width

        if sys.byteorder=='little':
            header += 'LITTLE_ENDIAN '
        elif sys.byteorder=='big':
            header += 'BIG_ENDIAN '
        else:
            raise TypeError('Unsupported sys.byteorder: '+repr(sys.byteorder))

        # pad to exactly 64 bytes (63 chars + newline)
        header += ' '*(63-len(header))+'\n'

        self.f.seek(0)
        self.f.write(header)

    def read_and_parse_header(self):
        """Read the 64-byte header and set length, width, elemtype,
        elemsize, rowsize and swap_bytes accordingly.

        Raises ValueError on a malformed header.
        """
        header = self.f.read(64)
        mat_type, l, w, data_type, endianness = header.split()
        if mat_type!='MATRIX':
            raise ValueError('Invalid file header (should start with MATRIX)')
        self.length = int(l)
        self.width = int(w)
        if endianness=='LITTLE_ENDIAN':
            byteorder = 'little'
        elif endianness=='BIG_ENDIAN':
            byteorder = 'big'
        else:
            raise ValueError('Invalid endianness in file header: '+endianness)
        # rows must be byteswapped on read/write iff the file's
        # endianness differs from the host's
        self.swap_bytes = (byteorder!=sys.byteorder)

        if data_type=='DOUBLE':
            self.elemtype = 'd'
            self.elemsize = 8
        elif data_type=='FLOAT':
            self.elemtype = 'f'
            self.elemsize = 4
        else:
            raise ValueError('Invalid data type in file header: '+data_type)
        self.rowsize = self.elemsize*self.width

    def load_fieldnames(self):
        """Load field names from <fname>.metadata/fieldnames, or
        generate "field_<i>" names when the file is absent."""
        self.fieldnames = []
        fieldnamefile = os.path.join(self.fname+'.metadata','fieldnames')
        if os.path.isfile(fieldnamefile):
            f = open(fieldnamefile)
            for row in f:
                row = row.split()
                if len(row)>0:
                    # first token is the name; trailing tokens ignored
                    self.fieldnames.append(row[0])
            f.close()
        else:
            self.fieldnames = [ "field_"+str(i) for i in range(self.width) ]

    def save_fieldnames(self):
        """Write self.fieldnames to <fname>.metadata/fieldnames,
        creating the metadata directory if needed."""
        metadatadir = self.fname+'.metadata'
        if not os.path.isdir(metadatadir):
            os.mkdir(metadatadir)
        fieldnamefile = os.path.join(metadatadir,'fieldnames')
        f = open(fieldnamefile,'wb')
        for name in self.fieldnames:
            f.write(name+'\t0\n')
        f.close()

    def getRow(self,i):
        """Return row i as a 1-D numarray of width elements."""
        if i<0 or i>=self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(self.rowsize)
        ar = numpy.numarray.fromstring(data, self.elemtype, (self.width,))
        if self.swap_bytes:
            ar.byteswap(True)
        return ar

    def getRows(self,i,l):
        """Return rows [i, i+l) as a 2-D numarray of shape (l, width)."""
        if i<0 or l<0 or i+l>self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(l*self.rowsize)
        ar = numpy.numarray.fromstring(data, self.elemtype, (l,self.width))
        if self.swap_bytes:
            ar.byteswap(True)
        return ar

    def checkzerorow(self,i):
        """Return True iff every element of row i is zero.

        NOTE(review): the bounds check allows i == self.length (unlike
        getRow), in which case zero bytes are read and True is returned
        vacuously — looks like an off-by-one, but preserved since
        callers may rely on it; confirm before tightening to i>=length.
        """
        if i<0 or i>self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(self.rowsize)
        ar = numpy.numarray.fromstring(data, self.elemtype, (len(data)/self.elemsize,))
        if self.swap_bytes:
            ar.byteswap(True)
        for elem in ar:
            if elem!=0:
                return False
        return True

    def putRow(self,i,row):
        """Overwrite row i with `row` (length must equal width)."""
        # NOTE: the original performed this same bounds check twice;
        # the duplicate has been removed.
        if i<0 or i>=self.length:
            raise IndexError('PMat index out of range')
        if len(row)!=self.width:
            raise TypeError('length of row ('+str(len(row))+ ') differs from matrix width ('+str(self.width)+')')
        if self.swap_bytes: # must make a copy and swap bytes
            ar = numpy.numarray.numarray(row,type=self.elemtype)
            ar.byteswap(True)
        else: # asarray makes a copy if not already a numarray of the right type
            ar = numpy.numarray.asarray(row,type=self.elemtype)
        self.f.seek(64+i*self.rowsize)
        self.f.write(ar.tostring())

    def appendRow(self,row):
        """Append `row` at the end of the matrix and update the header's
        length field."""
        if len(row)!=self.width:
            raise TypeError('length of row ('+str(len(row))+ ') differs from matrix width ('+str(self.width)+')')
        if self.swap_bytes: # must make a copy and swap bytes
            ar = numpy.numarray.numarray(row,type=self.elemtype)
            ar.byteswap(True)
        else: # asarray makes a copy if not already a numarray of the right type
            ar = numpy.numarray.asarray(row,type=self.elemtype)

        self.f.seek(64+self.length*self.rowsize)
        self.f.write(ar.tostring())
        self.length += 1
        self.write_header() # update length in header

    def flush(self):
        """Flush pending writes to disk."""
        self.f.flush()

    def close(self):
        """Close the underlying file; safe to call on a partially
        constructed instance."""
        if hasattr(self, 'f'):
            self.f.close()

    def append(self,row):
        """Alias for appendRow (list-like interface)."""
        self.appendRow(row)

    def __setitem__(self, i, row):
        # negative indices count from the end, matching __getitem__
        l = self.length
        if i<0: i+=l
        self.putRow(i,row)

    def __len__(self):
        return self.length
455 | |
456 | |
457 | |
458 #copied from PLEARNDIR:python_modules/plearn/vmat/readAMat.py | |
def safefloat(str):
    """Convert the given string to its float value. It is 'safe' in the sense
    that missing values ('nan') will be properly converted to the corresponding
    float value under all platforms, contrarily to 'float(str)'.

    (The parameter name shadows the builtin `str`; kept for interface
    compatibility.)
    """
    if str.lower() == 'nan':
        # float('nan') is portable on any Python >= 2.6 and removes the
        # dependency on the external fpconst module for this path.
        return float('nan')
    else:
        return float(str)
468 | |
469 #copied from PLEARNDIR:python_modules/plearn/vmat/readAMat.py | |
def readAMat(amatname):
    """Read a PLearn .amat file and return it as a numarray Array.

    Return a tuple, with as the first argument the array itself, and as
    the second argument the fieldnames (list of strings).  Fieldnames
    come from a "#: name1 name2 ..." line and are [] when absent.
    """
    ### NOTE: this version is much faster than first creating the array and
    ### updating each row as it is read... Bizarrely enough
    f = open(amatname)
    try:
        a = []
        fieldnames = []
        for line in f:
            if line.startswith("#size:"):
                # declared dimensions; rows are counted directly, so
                # these values are parsed but not otherwise used
                (length,width) = line[6:].strip().split()
            elif line.startswith("#sizes:"): # ignore input/target/weight/extra sizes
                continue
            elif line.startswith("#:"):
                fieldnames = line[2:].strip().split()
            elif not line.startswith('#'):
                # Add all non-comment lines.
                row = [ safefloat(x) for x in line.strip().split() ]
                if row:
                    a.append(row)
    finally:
        # close even when a malformed line raises (the original leaked
        # the handle on that path)
        f.close()
    return numpy.numarray.array(a), fieldnames
498 | |
499 | |
if __name__ == '__main__':
    # Smoke test: build a tiny 2x2 PMat, read it back through the three
    # loaders, round-trip it through a second file, then clean up both
    # files and their .metadata sidecar directories.
    pmat = PMat( 'tmp.pmat', 'w', fieldnames=['F1', 'F2'] )
    pmat.append( [1, 2] )
    pmat.append( [3, 4] )
    pmat.close()

    pmat = PMat( 'tmp.pmat', 'r' )
    ar=load_pmat_as_array('tmp.pmat')
    ds=load_pmat_as_array_dataset('tmp.pmat')

    print "PMat",pmat
    print "PMat",pmat[:]
    print "array",ar
    print "ArrayDataSet",ds
    for i in ds:
        print i
    # round-trip: dataset -> second pmat -> dataset again
    save_array_dataset_as_pmat("tmp2.pmat",ds)
    ds2=load_pmat_as_array_dataset('tmp2.pmat')
    for i in ds2:
        print i
#    print "+++ tmp.pmat contains: "
#    os.system( 'plearn vmat cat tmp.pmat' )
    import shutil
    for fname in ["tmp.pmat", "tmp2.pmat"]:
        os.remove( fname )
        if os.path.exists( fname+'.metadata' ):
            shutil.rmtree( fname+'.metadata' )