Mercurial > pylearn
comparison pmat.py @ 114:d6d42a0c1275
file copied from PLearn/python_modules/plearn/vmat/PMat.py
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Wed, 07 May 2008 12:18:11 -0400 |
parents | |
children | 01aa97a2212d |
comparison
equal
deleted
inserted
replaced
113:b6bc1e769b36 | 114:d6d42a0c1275 |
---|---|
1 ## Automatically adapted for numpy.numarray Jun 13, 2007 by python_numarray_to_numpy (-xsm) | |
2 | |
3 # PMat.py | |
4 # Copyright (C) 2005 Pascal Vincent | |
5 # | |
6 # Redistribution and use in source and binary forms, with or without | |
7 # modification, are permitted provided that the following conditions are met: | |
8 # | |
9 # 1. Redistributions of source code must retain the above copyright | |
10 # notice, this list of conditions and the following disclaimer. | |
11 # | |
12 # 2. Redistributions in binary form must reproduce the above copyright | |
13 # notice, this list of conditions and the following disclaimer in the | |
14 # documentation and/or other materials provided with the distribution. | |
15 # | |
16 # 3. The name of the authors may not be used to endorse or promote | |
17 # products derived from this software without specific prior written | |
18 # permission. | |
19 # | |
20 # THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR | |
21 # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES | |
22 # OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN | |
23 # NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED | |
25 # TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | |
26 # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF | |
27 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |
28 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
29 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
30 # | |
31 # This file is part of the PLearn library. For more information on the PLearn | |
32 # library, go to the PLearn Web site at www.plearn.org | |
33 | |
34 | |
35 # Author: Pascal Vincent | |
36 | |
37 #import numarray, sys, os, os.path | |
38 import numpy.numarray, sys, os, os.path | |
39 pyplearn_import_failed = False | |
40 try: | |
41 from plearn.pyplearn.plearn_repr import plearn_repr, format_list_elements | |
42 except ImportError: | |
43 pyplearn_import_failed = True | |
44 | |
45 | |
46 | |
def array_columns( a, cols ):
    """Return the columns of the 2-D array ``a`` selected by ``cols``.

    ``cols`` may be:
      * an integer -- a single column, returned as an (n, 1) array;
      * a slice -- resolved against the number of columns of ``a`` with
        ordinary Python slice semantics (None / negative / out-of-range
        bounds and steps);
      * any iterable of column indices.

    The result always keeps two dimensions (rows x selected columns).
    """
    a = numpy.asarray(a)
    if isinstance( cols, (int, numpy.integer) ):
        indices = [ cols ]
    elif isinstance( cols, slice ):
        # Resolve the slice against the real column count; using cols.stop
        # as the "length" broke on stop=None, negative starts and stops
        # past the last column.
        indices = list( range( *cols.indices( a.shape[1] ) ) )
    else:
        indices = list( cols )

    # numpy.take is the native equivalent of the numpy.numarray.take
    # compatibility wrapper, which no longer exists in NumPy.
    return numpy.take( a, indices, axis=1 )
58 | |
def load_pmat_as_array(fname):
    """Load a whole .pmat file into a 2-D numpy array.

    The file starts with a 64-byte ASCII header
    ``<type> <length> <width> <DOUBLE|FLOAT> <LITTLE_ENDIAN|BIG_ENDIAN>``
    followed by the matrix data, row after row.

    Returns a writable array of shape (length, width) with element type
    'd' (DOUBLE) or 'f' (FLOAT), converted to the machine's byte order.

    Raises ValueError for an unknown data type or endianness in the
    header, or when the data section is too short for length*width
    elements.
    """
    # Read header and payload and close the file deterministically.
    with open(fname, 'rb') as f:
        formatstr = f.read(64)
        datastr = f.read()

    # The header is ASCII text; decode it so the comparisons below also
    # work on Python 3, where read() returns bytes.
    structuretype, l, w, data_type, endianness = formatstr.decode('ascii').split()

    if data_type=='DOUBLE':
        elemtype = 'd'
    elif data_type=='FLOAT':
        elemtype = 'f'
    else:
        raise ValueError('Invalid data type in file header: '+data_type)

    if endianness=='LITTLE_ENDIAN':
        byteorder = 'little'
    elif endianness=='BIG_ENDIAN':
        byteorder = 'big'
    else:
        raise ValueError('Invalid endianness in file header: '+endianness)

    l = int(l)
    w = int(w)
    # Only the first length*width elements belong to the matrix.
    nbytes = l * w * numpy.dtype(elemtype).itemsize
    # frombuffer returns a read-only view of the bytes: copy so the caller
    # gets a writable array and byteswap(True) can work in place.
    X = numpy.frombuffer(datastr[:nbytes], dtype=elemtype).reshape((l, w)).copy()
    if byteorder!=sys.byteorder:
        X.byteswap(True)
    return X
85 | |
def save_array_as_pmat( fname, ar, fieldnames=None ):
    """Write the 2-D array ``ar`` to ``fname`` in .pmat format.

    The file gets a 64-byte ASCII header
    ``MATRIX <length> <width> <DOUBLE|FLOAT> <LITTLE_ENDIAN|BIG_ENDIAN>``
    (space padded, ending in a newline) followed by the data in C order
    and in the machine's native byte order, as the header declares.

    If ``fieldnames`` is non-empty it must have one entry per column; the
    names are written to ``<fname>.metadata/fieldnames`` as lines of the
    form "<name>\\t0".

    Raises TypeError if the element type is neither float64 ('d') nor
    float32 ('f'); in that case nothing is written to disk.
    """
    ar = numpy.asarray(ar)
    length, width = ar.shape
    if fieldnames:
        assert len(fieldnames) == width

    # Validate everything before touching the filesystem so a rejected
    # array cannot leave a truncated .pmat (or a leaked handle) behind.
    typecode = ar.dtype.char
    if typecode=='d':
        data_type = 'DOUBLE'
    elif typecode=='f':
        data_type = 'FLOAT'
    else:
        raise TypeError('Unsupported typecode: %s' % typecode)

    if sys.byteorder=='little':
        endianness = 'LITTLE_ENDIAN'
    elif sys.byteorder=='big':
        endianness = 'BIG_ENDIAN'
    else:
        raise TypeError('Unsupported sys.byteorder: '+repr(sys.byteorder))

    header = 'MATRIX %d %d %s %s ' % (length, width, data_type, endianness)
    header += ' '*(63-len(header))+'\n'

    if fieldnames:
        metadatadir = fname+'.metadata'
        if not os.path.isdir(metadatadir):
            os.mkdir(metadatadir)
        fieldnamefile = os.path.join(metadatadir,'fieldnames')
        with open(fieldnamefile,'wb') as f:
            for name in fieldnames:
                line = name+'\t0\n'
                # Text names must be encoded before going to a binary file
                # (a no-op for byte strings on Python 2).
                if not isinstance(line, bytes):
                    line = line.encode('utf-8')
                f.write(line)

    # The header promises native byte order: convert arrays stored with a
    # foreign byte order (and non-contiguous ones) before dumping them.
    native = numpy.ascontiguousarray(ar, dtype=typecode)
    # ndarray.tobytes() exists from NumPy 1.9 on; tostring() is its
    # deprecated predecessor, kept for old NumPy releases.
    to_bytes = getattr(native, 'tobytes', None) or native.tostring
    with open(fname, 'wb') as s:
        s.write( header.encode('ascii') )
        s.write( to_bytes() )
126 | |
127 | |
128 ####### Iterators ########################################################### | |
129 | |
class VMatIt:
    """Iterator over the rows of a VMat, in order.

    Rows are fetched one at a time with ``vmat.getRow(i)`` for
    i = 0 .. vmat.length-1.
    """
    def __init__(self, vmat):
        self.vmat = vmat
        self.cur_row = 0

    def __iter__(self):
        return self

    def next(self):
        """Return the next row, or raise StopIteration when exhausted."""
        # '>=' is a defensive bound: also stop if cur_row is already past
        # the length instead of requesting a nonexistent row.
        if self.cur_row >= self.vmat.length:
            raise StopIteration
        row = self.vmat.getRow(self.cur_row)
        self.cur_row += 1
        return row

    # Python 3 names the iterator-protocol method __next__.
    __next__ = next
144 | |
class ColumnIt:
    """Iterator over the values of one column of a VMat, row by row.

    Each value is obtained as ``vmat[row, col]`` for
    row = 0 .. vmat.length-1.
    """
    def __init__(self, vmat, col):
        self.vmat = vmat
        self.col = col
        self.cur_row = 0

    def __iter__(self):
        return self

    def next(self):
        """Return the next column value, or raise StopIteration."""
        # '>=' is a defensive bound: also stop if cur_row is already past
        # the length instead of requesting a nonexistent row.
        if self.cur_row >= self.vmat.length:
            raise StopIteration
        val = self.vmat[self.cur_row, self.col]
        self.cur_row += 1
        return val

    # Python 3 names the iterator-protocol method __next__.
    __next__ = next
160 | |
161 ####### VMat classes ######################################################## | |
162 | |
class VMat:
    """Base class for row-accessible matrices.

    Subclasses provide the attributes ``length``, ``width`` and
    ``fieldnames`` and the methods ``getRow(i)`` (one row) and
    ``getRows(i, l)`` (l consecutive rows starting at i).  The results are
    expected to be array-like objects with a ``shape`` attribute (the
    tuple-indexing path relies on it).
    """
    def __iter__(self):
        return VMatIt(self)

    def __getitem__( self, key ):
        """Index the matrix.

        * ``vm[i]``      -- row i; a negative i counts from the end;
        * ``vm[a:b]``    -- rows a..b-1 with Python slice semantics (None and
                            negative bounds, clamped to the length); a step
                            is not supported and raises IndexError;
        * ``vm[r, c]``   -- the rows selected by r, then the columns
                            selected by c;
        * ``vm['name']`` -- the column whose field name is 'name', as a
                            (length, 1) array.
        """
        if isinstance( key, slice ):
            if key.step is not None:
                raise IndexError('Extended slice with step not currently supported')
            # slice.indices resolves None / negative bounds and clamps them
            # to [0, length], exactly like slicing a list.
            start, stop, _ = key.indices( self.length )
            return self.getRows( start, max( 0, stop - start ) )

        elif isinstance( key, tuple ):
            # Basically returns a SubVMatrix
            assert len(key) == 2
            rows = self.__getitem__( key[0] )

            if len(rows.shape) == 1:
                # A single row was selected: index it directly.
                return rows[ key[1] ]

            cols = key[1]
            if isinstance(cols, slice):
                # Resolve the column slice against the matrix width (handles
                # negative starts/stops and steps) and pass explicit indices.
                cols = list( range( *cols.indices( self.width ) ) )

            return array_columns(rows, cols)

        elif isinstance( key, str ):
            # The key is considered to be a fieldname and a column is
            # returned.  Look the name up first so an unknown field does not
            # read the whole matrix before failing.
            try:
                col = self.fieldnames.index(key)
            except ValueError:
                sys.stderr.write("Key is '%s' while fieldnames are:\n" % key)
                sys.stderr.write('%s\n' % (self.fieldnames,))
                raise
            return array_columns( self.getRows(0,self.length), col )

        else:
            if key<0: key+=self.length
            return self.getRow(key)

    def getFieldIndex(self, fieldname):
        """Return the column index of ``fieldname``.

        Raises ValueError, listing the known field names, if it is absent.
        """
        try:
            return self.fieldnames.index(fieldname)
        except ValueError:
            raise ValueError( "VMat has no field named %s. Field names: %s"
                              %(fieldname, ','.join(self.fieldnames)) )
227 | |
class PMat( VMat ):
    """File-backed matrix stored in PLearn's .pmat format.

    Layout handled here: a 64-byte ASCII header
    ``MATRIX <length> <width> <DOUBLE|FLOAT> <LITTLE_ENDIAN|BIG_ENDIAN>``
    (space padded, ending in a newline) followed by the matrix data, row
    after row.  Field names live in ``<fname>.metadata/fieldnames``, one
    line per column: the name, a tab, then ``0``.

    openmode:
      'r' -- open an existing file read-only;
      'w' -- create (or truncate) a file whose width is len(fieldnames),
             with element type ``elemtype`` ('d' = DOUBLE, 'f' = FLOAT);
      'a' -- open an existing file so rows can be read, replaced with
             putRow or appended with appendRow.

    ``inputsize``, ``targetsize`` and ``weightsize`` are stored as given
    and reported by ``plearn_repr`` when pyplearn is importable.  If
    ``array`` is given, its rows are appended once the file is open; a 1-D
    array is appended as one single-element row per entry.
    """

    def __init__(self, fname, openmode='r', fieldnames=None, elemtype='d',
                 inputsize=-1, targetsize=-1, weightsize=-1, array = None):
        self.fname = fname
        self.inputsize = inputsize
        self.targetsize = targetsize
        self.weightsize = weightsize
        if openmode=='r':
            self.f = open(fname,'rb')
            self.read_and_parse_header()
            self.load_fieldnames()

        elif openmode=='w':
            # Validate before open() so an unsupported elemtype cannot
            # truncate an existing file.
            if elemtype not in ('d', 'f'):
                raise TypeError('Unsupported elemtype: '+repr(elemtype))
            self.f = open(fname,'w+b')
            # Keep a private copy: never alias the caller's list.
            if fieldnames is None:
                fieldnames = []
            self.fieldnames = list(fieldnames)
            self.save_fieldnames()
            self.length = 0
            self.width = len(self.fieldnames)
            self.elemtype = elemtype
            self.swap_bytes = False  # new files are written in native order
            self.write_header()

        elif openmode=='a':
            self.f = open(fname,'r+b')
            self.read_and_parse_header()
            self.load_fieldnames()

        else:
            raise ValueError("Currently only supported openmodes are 'r', 'w' and 'a': "+repr(openmode)+" is not supported")

        if array is not None:
            ndim = len(array.shape)
            if ndim == 1:
                # Each scalar becomes a one-column row.
                rows = ( [ r ] for r in array )
            elif ndim == 2:
                rows = array
            else:
                raise ValueError('array must be 1-D or 2-D, got shape %r'
                                 % (array.shape,))
            for row in rows:
                self.appendRow( row )

    def __del__(self):
        self.close()

    def write_header(self):
        """(Re)write the 64-byte header at the start of the file.

        Also sets ``elemsize`` and ``rowsize`` from ``elemtype``/``width``.
        Raises TypeError for an element type other than 'd' or 'f'.
        """
        header = 'MATRIX ' + str(self.length) + ' ' + str(self.width) + ' '

        if self.elemtype=='d':
            header += 'DOUBLE '
            self.elemsize = 8
        elif self.elemtype=='f':
            header += 'FLOAT '
            self.elemsize = 4
        else:
            raise TypeError('Unsupported elemtype: '+repr(self.elemtype))
        self.rowsize = self.elemsize*self.width

        if sys.byteorder=='little':
            header += 'LITTLE_ENDIAN '
        elif sys.byteorder=='big':
            header += 'BIG_ENDIAN '
        else:
            raise TypeError('Unsupported sys.byteorder: '+repr(sys.byteorder))

        header += ' '*(63-len(header))+'\n'

        self.f.seek(0)
        self.f.write(header.encode('ascii'))

    def read_and_parse_header(self):
        """Parse the 64-byte header and set the matrix attributes.

        Sets length, width, elemtype, elemsize, rowsize and swap_bytes
        (True when the file's byte order differs from this machine's).
        Raises ValueError on an unrecognised header.
        """
        header = self.f.read(64)
        # The header is ASCII text; decode so parsing also works on the
        # bytes returned by read() on Python 3.
        mat_type, l, w, data_type, endianness = header.decode('ascii').split()
        if mat_type!='MATRIX':
            raise ValueError('Invalid file header (should start with MATRIX)')
        self.length = int(l)
        self.width = int(w)
        if endianness=='LITTLE_ENDIAN':
            byteorder = 'little'
        elif endianness=='BIG_ENDIAN':
            byteorder = 'big'
        else:
            raise ValueError('Invalid endianness in file header: '+endianness)
        self.swap_bytes = (byteorder!=sys.byteorder)

        if data_type=='DOUBLE':
            self.elemtype = 'd'
            self.elemsize = 8
        elif data_type=='FLOAT':
            self.elemtype = 'f'
            self.elemsize = 4
        else:
            raise ValueError('Invalid data type in file header: '+data_type)
        self.rowsize = self.elemsize*self.width

    def load_fieldnames(self):
        """Read field names from the metadata directory.

        Uses the first whitespace-separated token of each non-blank line;
        without a fieldnames file, names default to field_0 .. field_<w-1>.
        """
        self.fieldnames = []
        fieldnamefile = os.path.join(self.fname+'.metadata','fieldnames')
        if os.path.isfile(fieldnamefile):
            with open(fieldnamefile) as f:
                for row in f:
                    row = row.split()
                    if len(row)>0:
                        self.fieldnames.append(row[0])
        else:
            self.fieldnames = [ "field_"+str(i) for i in range(self.width) ]

    def save_fieldnames(self):
        """Write ``self.fieldnames`` to <fname>.metadata/fieldnames."""
        metadatadir = self.fname+'.metadata'
        if not os.path.isdir(metadatadir):
            os.mkdir(metadatadir)
        fieldnamefile = os.path.join(metadatadir,'fieldnames')
        with open(fieldnamefile,'wb') as f:
            for name in self.fieldnames:
                line = name+'\t0\n'
                # Text names must be encoded for a binary file (a no-op for
                # byte strings on Python 2).
                if not isinstance(line, bytes):
                    line = line.encode('utf-8')
                f.write(line)

    def _array_from_bytes(self, data, shape):
        """Decode raw file bytes into a writable native-order array."""
        # frombuffer gives a read-only view; copy so callers may modify the
        # result and byteswap(True) can operate in place.
        ar = numpy.frombuffer(data, dtype=self.elemtype).reshape(shape).copy()
        if self.swap_bytes:
            ar.byteswap(True)
        return ar

    def _row_to_bytes(self, row):
        """Encode ``row`` in the file's element type and byte order."""
        if self.swap_bytes: # must make a copy and swap bytes
            ar = numpy.array(row, dtype=self.elemtype)
            ar.byteswap(True)
        else: # asarray only copies if row is not already of the right type
            ar = numpy.asarray(row, dtype=self.elemtype)
        # ndarray.tobytes() exists from NumPy 1.9 on; tostring() is its
        # deprecated predecessor, kept for old NumPy releases.
        to_bytes = getattr(ar, 'tobytes', None) or ar.tostring
        return to_bytes()

    def getRow(self,i):
        """Return row i (0 <= i < length) as a 1-D array of width elements."""
        if i<0 or i>=self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(self.rowsize)
        return self._array_from_bytes(data, (self.width,))

    def getRows(self,i,l):
        """Return rows i .. i+l-1 as an (l, width) array."""
        if i<0 or l<0 or i+l>self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(l*self.rowsize)
        return self._array_from_bytes(data, (l,self.width))

    def checkzerorow(self,i):
        """Return True if every element stored for row i is zero.

        Only the bytes actually present in the file are inspected, so a
        truncated (or absent) row is judged on what is there.
        """
        # NOTE(review): unlike getRow, i == self.length is accepted here and
        # inspects whatever follows the last row -- confirm this is intended.
        if i<0 or i>self.length:
            raise IndexError('PMat index out of range')
        self.f.seek(64+i*self.rowsize)
        data = self.f.read(self.rowsize)
        # Integer division: ignore a trailing partial element.
        nelems = len(data)//self.elemsize
        ar = self._array_from_bytes(data[:nelems*self.elemsize], (nelems,))
        return not ar.any()

    def putRow(self,i,row):
        """Overwrite existing row i (0 <= i < length) with ``row``."""
        if i<0 or i>=self.length:
            raise IndexError('PMat index out of range')
        if len(row)!=self.width:
            raise TypeError('length of row ('+str(len(row))+ ') differs from matrix width ('+str(self.width)+')')
        data = self._row_to_bytes(row)
        self.f.seek(64+i*self.rowsize)
        self.f.write(data)

    def appendRow(self,row):
        """Append ``row`` after the last row and update the header length."""
        if len(row)!=self.width:
            raise TypeError('length of row ('+str(len(row))+ ') differs from matrix width ('+str(self.width)+')')
        data = self._row_to_bytes(row)
        self.f.seek(64+self.length*self.rowsize)
        self.f.write(data)
        self.length += 1
        self.write_header() # update length in header

    def flush(self):
        self.f.flush()

    def close(self):
        # __init__ may have failed before self.f was assigned.
        if hasattr(self, 'f'):
            self.f.close()

    def append(self,row):
        self.appendRow(row)

    def __setitem__(self, i, row):
        l = self.length
        if i<0: i+=l
        self.putRow(i,row)

    def __len__(self):
        return self.length

    if not pyplearn_import_failed:
        def __str__( self ):
            return plearn_repr(self, indent_level=0)

        def plearn_repr( self, indent_level=0, inner_repr=plearn_repr ):
            # A plearn_repr may be handed to another program that will open
            # the .pmat, so make sure all data is flushed to disk first.
            self.flush()

            def elem_format( elem ):
                k, v = elem
                return '%s = %s' % ( k, inner_repr(v, indent_level+1) )

            options = [ ( 'filename', self.fname ),
                        ( 'inputsize', self.inputsize ),
                        ( 'targetsize', self.targetsize ),
                        ( 'weightsize', self.weightsize ) ]
            return 'FileVMatrix(%s)' % format_list_elements( options, elem_format, indent_level+1 )
444 | |
if __name__ == '__main__':
    # Small self-test: write a two-column .pmat, read it back and print it,
    # then delete the file and its metadata directory.
    pmat = PMat( 'tmp.pmat', 'w', fieldnames=['F1', 'F2'] )
    pmat.append( [1, 2] )
    pmat.append( [3, 4] )
    pmat.close()

    pmat = PMat( 'tmp.pmat', 'r' )
    # print(x) with a single argument behaves the same on Python 2 and 3.
    print( pmat )
    print( pmat[:] )
    # Close before deleting: removing an open file fails on Windows.
    pmat.close()

    os.remove( 'tmp.pmat' )
    if os.path.exists( 'tmp.pmat.metadata' ):
        import shutil
        shutil.rmtree( 'tmp.pmat.metadata' )