Mercurial > pylearn
diff amat.py @ 266:6e69fb91f3c0
initial commit of amat
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 04 Jun 2008 17:49:09 -0400 |
parents | |
children | bd937e845bbb |
line wrap: on
line diff
"""load PLearn AMat files"""

import array
import sys

import numpy

path_MNIST = '/u/bergstrj/pub/data/mnist.amat'


class AMat:
    """DataSource to access a plearn amat file as a periodic unrandomized stream.

    Attributes:

    input -- minibatch of input
    target -- minibatch of target
    weight -- minibatch of weight
    extra -- minibatch of extra

    all -- the entire data contents of the amat file
    n_examples -- the number of training examples in the file

    input, target, weight and extra are column slices of `all`, taken left to
    right using the widths given on the '#sizes:' header line.  Any of them
    for which no width was given (or all four, when there is no '#sizes:'
    line) stay None.

    AMat stands for Ascii Matri[x,ces]

    """

    marker_size = '#size:'
    marker_sizes = '#sizes:'
    marker_col_names = '#:'

    def __init__(self, path, head=None, update_interval=0, ofile=sys.stdout):
        """Load the amat at <path> into memory.

        path - str: location of amat file
        head - int: stop reading after this many data rows (None reads all)
        update_interval - int: print '.' to ofile every <this many> data rows
            (0 disables progress output)
        ofile - file: print status, msgs, etc. to this file (None disables)

        Raises IOError if a data row does not have the same number of values
        as the first data row.

        """
        self.all = None
        self.input = None
        self.target = None
        self.weight = None
        self.extra = None

        self.header = False
        self.header_size = None
        self.header_rows = None
        self.header_cols = None
        self.header_sizes = None
        self.header_col_names = []

        data_started = False
        data = array.array('d')
        n_data_lines = 0
        len_float_line = None
        show_progress = update_interval > 0 and ofile is not None

        # `with` guarantees the file is closed even when a row fails to parse.
        with open(path) as f:
            for i, line in enumerate(f):
                if n_data_lines == head:
                    # we've read enough data,
                    # break even if there's more in the file
                    break
                if not line.strip():
                    # blank line (including whitespace-only and '\r\n')
                    continue
                if line[0] == '#':
                    if not data_started:
                        # a '#' line before any data row is a header line
                        self.header = True
                        self._parse_header_line(line)
                    # '#' lines after the data has started are ignored
                    continue

                # the first non-commented line tells us that the header is done
                data_started = True
                float_line = [float(s) for s in line.split()]
                if len_float_line is None:
                    # the first data row fixes the number of columns
                    len_float_line = len(float_line)
                    if (self.header_cols is not None
                            and self.header_cols != len_float_line):
                        sys.stderr.write(
                            'WARNING: header declared %i cols but first line'
                            ' has %i, using %i\n'
                            % (self.header_cols, len_float_line,
                               len_float_line))
                elif len_float_line != len(float_line):
                    raise IOError('wrong line length', i, line)
                data.extend(float_line)
                n_data_lines += 1

                if show_progress and n_data_lines % update_interval == 0:
                    ofile.write('.')
                    ofile.flush()

        if show_progress:
            ofile.write('\n')

        # convert from array.array to numpy.ndarray
        if len_float_line:
            # floor division keeps the row count an int on Python 2 and 3
            nshape = (len(data) // len_float_line, len_float_line)
            self.all = numpy.frombuffer(data, dtype=numpy.float64).reshape(nshape)
        else:
            # no data rows were read (empty file, header only, or head == 0)
            self.all = numpy.zeros((0, self.header_cols or 0))
        self.n_examples = self.all.shape[0]

        # assign
        if self.header_sizes is not None:
            if len(self.header_sizes) > 4:
                sys.stderr.write(
                    'WARNING: ignoring sizes after 4th in %s\n' % path)
            leftmost = 0
            # zip stops at the shorter sequence, so fewer than 4 sizes simply
            # leaves the remaining attributes as None
            attrlist = ['input', 'target', 'weight', 'extra']
            for attr, ncols in zip(attrlist, self.header_sizes):
                setattr(self, attr, self.all[:, leftmost:leftmost + ncols])
                leftmost += ncols

    def _parse_header_line(self, line):
        """Record the contents of one '#' header line; unknown lines are ignored."""
        if line.startswith(AMat.marker_size):
            info = line[len(AMat.marker_size):]
            self.header_size = [int(s) for s in info.split()]
            # unpacking requires exactly two ints: rows, cols
            self.header_rows, self.header_cols = self.header_size
        elif line.startswith(AMat.marker_sizes):
            info = line[len(AMat.marker_sizes):]
            self.header_sizes = [int(s) for s in info.split()]
        elif line.startswith(AMat.marker_col_names):
            info = line[len(AMat.marker_col_names):]
            self.header_col_names = info.split()