Mercurial > pylearn
comparison amat.py @ 266:6e69fb91f3c0
initial commit of amat
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 04 Jun 2008 17:49:09 -0400 |
parents | |
children | bd937e845bbb |
comparison
equal
deleted
inserted
replaced
263:5614b186c5f4 | 266:6e69fb91f3c0 |
---|---|
1 """load PLearn AMat files""" | |
2 | |
3 import sys, numpy, array | |
4 | |
# Hard-coded absolute path to an MNIST file in AMat format; it only resolves
# on machines where this exact directory exists.
path_MNIST = '/u/bergstrj/pub/data/mnist.amat'
6 | |
7 | |
class AMat:
    """DataSource to access a plearn amat file as a periodic unrandomized stream.

    Attributes:

    input -- minibatch of input
    target -- minibatch of target
    weight -- minibatch of weight
    extra -- minibatch of extra

    all -- the entire data contents of the amat file
    n_examples -- the number of training examples in the file

    header -- True if any '#' line preceded the first data row
    header_size, header_rows, header_cols -- parsed from a '#size:' line
    header_sizes -- list of ints parsed from a '#sizes:' line
    header_col_names -- list of names parsed from a '#:' line

    AMat stands for Ascii Matri[x,ces]

    """

    marker_size = '#size:'
    marker_sizes = '#sizes:'
    marker_col_names = '#:'

    def __init__(self, path, head=None, update_interval=0, ofile=sys.stdout):

        """Load the amat at <path> into memory.

        path - str: location of amat file
        head - int: stop reading after this many data rows. Header lines
               are still parsed, so head=0 gives an empty matrix whose
               column count comes from '#size:' when present.
        update_interval - int: print '.' to ofile every <this many> lines
        ofile - file: print status, msgs, etc. to this file (None disables
                progress output)

        Raises IOError if a data row has a different number of columns than
        the first data row, and ValueError if a data token is not a float.

        """
        self.all = None
        self.input = None
        self.target = None
        self.weight = None
        self.extra = None

        self.header = False
        self.header_size = None
        self.header_rows = None
        self.header_cols = None
        self.header_sizes = None
        self.header_col_names = []

        data = array.array('d')
        n_data_lines = 0
        len_float_line = None
        # progress output needs both a positive interval and somewhere to write
        show_progress = update_interval > 0 and ofile is not None

        # 'with' guarantees the file is closed even if a row fails to parse
        with open(path) as f:
            data_started = False
            for i, line in enumerate(f):
                # skip empty and whitespace-only lines (including '\r\n')
                if not line.strip():
                    continue
                if line[0] == '#':
                    # '#' lines before the first data row form the header;
                    # '#' lines after it are treated as comments and ignored
                    if not data_started:
                        self.header = True
                        self._parse_header_line(line)
                    continue
                if n_data_lines == head:
                    # we've read enough data,
                    # stop even if there's more in the file
                    break
                # the first non-commented line tells us that the header is done
                data_started = True
                float_line = [float(s) for s in line.split()]
                if len_float_line is None:
                    len_float_line = len(float_line)
                    if (self.header_cols is not None) \
                            and self.header_cols != len_float_line:
                        sys.stderr.write(
                            'WARNING: header declared %i cols but first line '
                            'has %i, using %i\n'
                            % (self.header_cols, len_float_line, len_float_line))
                elif len_float_line != len(float_line):
                    raise IOError('wrong line length', i, line)
                data.extend(float_line)
                n_data_lines += 1

                if show_progress and n_data_lines % update_interval == 0:
                    ofile.write('.')
                    ofile.flush()

        if show_progress:
            ofile.write('\n')

        if len_float_line is None:
            # no data rows (empty/header-only file, or head=0): build an empty
            # matrix, using the declared column count when there is one
            ncols = self.header_cols if self.header_cols is not None else 0
            self.all = numpy.zeros((0, ncols))
        else:
            # convert from array.array to numpy.ndarray (shares the buffer)
            self.all = numpy.frombuffer(data).reshape(
                (n_data_lines, len_float_line))
        self.n_examples = self.all.shape[0]

        self._assign_fields(path)

    def _parse_header_line(self, line):
        """Record '#size:', '#:' or '#sizes:' metadata found on a header line."""
        if line.startswith(AMat.marker_size):
            info = line[len(AMat.marker_size):]
            self.header_size = [int(s) for s in info.split()]
            self.header_rows, self.header_cols = self.header_size
        elif line.startswith(AMat.marker_col_names):
            info = line[len(AMat.marker_col_names):]
            self.header_col_names = info.split()
        elif line.startswith(AMat.marker_sizes):
            info = line[len(AMat.marker_sizes):]
            self.header_sizes = [int(s) for s in info.split()]

    def _assign_fields(self, path):
        """Slice self.all column-wise into input/target/weight/extra per '#sizes:'."""
        if self.header_sizes is None:
            return
        if len(self.header_sizes) > 4:
            sys.stderr.write('WARNING: ignoring sizes after 4th in %s\n' % path)
        leftmost = 0
        # zip stops at the shorter sequence: with fewer than 4 sizes the
        # remaining attributes stay None, and sizes past the 4th are dropped
        attrlist = ['input', 'target', 'weight', 'extra']
        for attr, ncols in zip(attrlist, self.header_sizes):
            setattr(self, attr, self.all[:, leftmost:leftmost + ncols])
            leftmost += ncols