comparison amat.py @ 266:6e69fb91f3c0

initial commit of amat
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 04 Jun 2008 17:49:09 -0400
parents
children bd937e845bbb
comparison
equal deleted inserted replaced
263:5614b186c5f4 266:6e69fb91f3c0
1 """load PLearn AMat files"""
2
3 import sys, numpy, array
4
# Absolute path to an MNIST dataset stored in amat format.
# NOTE(review): this points into a specific user's directory -- presumably
# only valid on the original author's site; confirm before relying on it.
path_MNIST = '/u/bergstrj/pub/data/mnist.amat'
6
7
class AMat:
    """Load a PLearn amat file into memory.

    AMat stands for Ascii Matri[x,ces]

    Attributes:

    all -- 2-D numpy array (float64) holding every data row that was read
    n_examples -- the number of data rows that were read (all.shape[0])

    input -- view of the leading columns of `all` (1st '#sizes:' entry)
    target -- view of the next columns of `all` (2nd '#sizes:' entry)
    weight -- view of the next columns of `all` (3rd '#sizes:' entry)
    extra -- view of the next columns of `all` (4th '#sizes:' entry)

    input, target, weight and extra stay None when the file has no
    '#sizes:' header line, or when that line has too few entries to reach
    them.

    header -- True if at least one '#' line was seen before the data
    header_size -- list of ints parsed from the '#size:' line, or None
    header_rows, header_cols -- the two entries of header_size, or None
    header_sizes -- list of ints parsed from the '#sizes:' line, or None
    header_col_names -- list of names parsed from the '#:' line, or []

    """

    marker_size = '#size:'
    marker_sizes = '#sizes:'
    marker_col_names = '#:'

    def __init__(self, path, head=None, update_interval=0, ofile=sys.stdout):
        """Load the amat at <path> into memory.

        path - str: location of amat file
        head - int: stop reading after this many data rows
            (None reads the whole file)
        update_interval - int: print '.' to ofile every <this many> data
            rows; 0 disables progress output
        ofile - file: progress output goes to this file; None disables it

        Raises IOError if a data row does not have the same number of
        columns as the first data row, and ValueError if a data or header
        field cannot be parsed as a number.

        Warnings (header/data column mismatch, more than four '#sizes:'
        entries) are written to sys.stderr.

        """
        self.all = None
        self.input = None
        self.target = None
        self.weight = None
        self.extra = None

        self.header = False
        self.header_size = None
        self.header_rows = None
        self.header_cols = None
        self.header_sizes = None
        self.header_col_names = []

        data_started = False
        data = array.array('d')
        n_data_lines = 0
        len_float_line = None
        show_progress = update_interval > 0 and ofile is not None

        # 'with' guarantees the file is closed even when a malformed row
        # raises part-way through the loop.
        with open(path) as f:
            for i, line in enumerate(f):
                if n_data_lines == head:
                    # we've read enough data,
                    # break even if there's more in the file
                    break
                if not line.strip():
                    # skip blank lines, including whitespace-only and
                    # '\r\n' lines, instead of parsing them as empty rows
                    continue
                if line[0] == '#':
                    if not data_started:
                        # the file has a header and we're on one of its
                        # lines; comment lines after the data started are
                        # ignored
                        self.header = True
                        if line.startswith(AMat.marker_size):
                            info = line[len(AMat.marker_size):]
                            self.header_size = [int(s) for s in info.split()]
                            self.header_rows, self.header_cols = \
                                self.header_size
                        elif line.startswith(AMat.marker_sizes):
                            info = line[len(AMat.marker_sizes):]
                            self.header_sizes = [int(s) for s in info.split()]
                        elif line.startswith(AMat.marker_col_names):
                            info = line[len(AMat.marker_col_names):]
                            self.header_col_names = info.split()
                    continue

                # the first non-commented line tells us the header is done
                data_started = True
                float_line = [float(s) for s in line.split()]
                if len_float_line is None:
                    # the first data row fixes the column count
                    len_float_line = len(float_line)
                    if (self.header_cols is not None
                            and self.header_cols != len_float_line):
                        sys.stderr.write(
                            'WARNING: header declared %i cols but first '
                            'line has %i, using %i\n'
                            % (self.header_cols, len_float_line,
                               len_float_line))
                elif len_float_line != len(float_line):
                    raise IOError('wrong line length', i, line)
                data.extend(float_line)
                n_data_lines += 1

                # progress is counted in data rows, so header lines never
                # produce a dot
                if show_progress and n_data_lines % update_interval == 0:
                    ofile.write('.')
                    ofile.flush()

        if show_progress:
            ofile.write('\n')

        # convert from array.array to numpy.ndarray
        if len_float_line is None:
            # no data rows were read: build an empty matrix, using the
            # header's column count when one was declared
            self.all = numpy.zeros((0, self.header_cols or 0))
        else:
            # floor division keeps the shape integral on Python 2 and 3
            nshape = (len(data) // len_float_line, len_float_line)
            self.all = numpy.frombuffer(data).reshape(nshape)
        self.n_examples = self.all.shape[0]

        # assign
        if self.header_sizes is not None:
            if len(self.header_sizes) > 4:
                sys.stderr.write(
                    'WARNING: ignoring sizes after 4th in %s\n' % path)
            leftmost = 0
            # zip stops at the shorter sequence, so fewer than 4 sizes
            # leaves the remaining attributes as None and sizes past the
            # 4th are ignored
            attrlist = ['input', 'target', 'weight', 'extra']
            for attr, ncols in zip(attrlist, self.header_sizes):
                setattr(self, attr, self.all[:, leftmost:leftmost + ncols])
                leftmost += ncols
114 if len(self.header_sizes) > 4:
115 print >> sys.stderr, 'WARNING: ignoring sizes after 4th in %s' % path
116 leftmost = 0
117 #here we make use of the fact that if header_sizes has len < 4
118 # the loop will exit before 4 iterations
119 attrlist = ['input', 'target', 'weight', 'extra']
120 for attr, ncols in zip(attrlist, self.header_sizes):
121 setattr(self, attr, self.all[:, leftmost:leftmost+ncols])
122 leftmost += ncols
123