Mercurial > ift6266
view utils/tables_series/series.py @ 208:acb942530923
Completely rewrote my series module, now based on HDF5 and PyTables (in a separate directory called 'tables_series' for retrocompatibility of running code). Minor (inconsequential) changes to stacked_dae.
author | fsavard |
---|---|
date | Fri, 05 Mar 2010 18:07:20 -0500 |
parents | |
children | dc0d77c8a878 |
line wrap: on
line source
from tables import *
import numpy

'''
The way these "IsDescription constructor" work is simple: write the
code as if it were in a file, then exec()ute it in an explicit
namespace, from which the resulting LocalDescription class is fetched
and may be used to call createTable.

It's a small hack, but it's necessary as the names of the columns
are retrieved based on the variable name, which we can't programmatically
set otherwise.
'''

def get_beginning_description_n_ints(int_names, int_width=64):
    """
    Builds the beginning of the source code of a "LocalDescription" class:
    the class header followed by one integer column per name.

    Parameters
    ----------
    int_names : tuple of str
        Names of the int (e.g. index) columns, in the order they should
        appear in the table.
    int_width : {32, 64}
        Width of the ints. Any value other than 32 selects 64-bit columns.

    Returns
    -------
    str
        Source code to be completed (optionally) with more columns and
        then exec()uted.
    """
    int_constructor = "Int64Col"
    if int_width == 32:
        int_constructor = "Int32Col"

    toexec = "class LocalDescription(IsDescription):\n"

    # Each column gets its own position so that the column order of the
    # table follows the order of int_names.
    for pos, n in enumerate(int_names):
        toexec += "\t" + n + " = " + int_constructor + "(pos=" + str(pos) + ")\n"

    return toexec

def get_description_with_n_ints_n_floats(int_names, float_names, int_width=64, float_width=32):
    """
    Constructs a class to be used when constructing a table with PyTables.

    This is useful to construct a series with an index with multiple levels.
    E.g. if you want to index your "validation error" with "epoch" first, then
    "minibatch_index" second, you'd use two "int_names".

    Parameters
    ----------
    int_names : tuple of str
        Names of the int (e.g. index) columns
    float_names : tuple of str
        Names of the float (e.g. error) columns
    int_width : {32, 64}
        Type of ints.
    float_width : {32, 64}
        Type of floats.

    Returns
    -------
    A class object, to pass to createTable()

    Raises
    ------
    ValueError
        If no column name at all is provided (the generated class would
        have an empty body).
    """
    if len(int_names) == 0 and len(float_names) == 0:
        raise ValueError("at least one int or float column name is required")

    toexec = get_beginning_description_n_ints(int_names, int_width=int_width)

    float_constructor = "Float32Col"
    if float_width == 64:
        float_constructor = "Float64Col"

    # Float columns are positioned right after the int columns, each one
    # at its own position.
    for pos, n in enumerate(float_names, len(int_names)):
        toexec += "\t" + n + " = " + float_constructor + "(pos=" + str(pos) + ")\n"

    # exec() cannot reliably create local variables of the enclosing
    # function, so the class is created in an explicit namespace and
    # fetched from there. This call form is valid in Python 2 and 3.
    namespace = {}
    exec(toexec, globals(), namespace)

    return namespace["LocalDescription"]

class Series():
    """
    Base class of the series: holds what identifies a series inside an
    HDF5 file. Subclasses create the table and implement append().
    """
    def __init__(self, table_name, hdf5_file, index_names=('epoch',), title=None, hdf5_group='/'):
        """
        Parameters
        ----------
        table_name : str
            Name of the table holding the series.
        hdf5_file : open HDF5 file object
            File in which the series is stored.
        index_names : tuple of str
            Names of the index columns.
        title : str
            This is used as metadata in the HDF5 file to identify the series
        hdf5_group : str
            Group under which the table is to be created.
        """
        self.table_name = table_name
        self.hdf5_file = hdf5_file
        self.index_names = index_names
        self.title = title
        self.hdf5_group = hdf5_group

    def append(self, index, element):
        """To be implemented by subclasses."""
        raise NotImplementedError

class ErrorSeries(Series):
    """
    Series with one float column (the error), indexed by one or more
    int columns.
    """
    def __init__(self, error_name, table_name, hdf5_file, index_names=('epoch',), title=None, hdf5_group='/'):
        """
        Creates the table in hdf5_file, under hdf5_group.

        Parameters
        ----------
        error_name : str
            Name of the float column holding the error.

        The other parameters are the ones of Series.__init__().
        """
        Series.__init__(self, table_name, hdf5_file, index_names, title, hdf5_group)

        self.error_name = error_name

        table_description = self._get_table_description()

        self._table = hdf5_file.createTable(hdf5_group, self.table_name, table_description, title=title)

    def _get_table_description(self):
        """Returns the description class: index columns, then the error column."""
        return get_description_with_n_ints_n_floats(self.index_names, (self.error_name,))

    def append(self, index, error):
        """
        Adds one row to the table and flushes the file.

        Parameters
        ----------
        index : tuple of int
            One value per name in index_names.
        error : float
            Value for the error column.

        Raises
        ------
        ValueError
            If index does not have as many elements as index_names.
        """
        if len(index) != len(self.index_names):
            raise ValueError("index provided does not have the right length (expected " \
                            + str(len(self.index_names)) + " got " + str(len(index)) + ")")

        newrow = self._table.row

        for col_name, value in zip(self.index_names, index):
            newrow[col_name] = value
        newrow[self.error_name] = error

        newrow.append()

        # Flushing at every append so the file on disk is always up to date.
        self.hdf5_file.flush()

# Does not inherit from Series because it does not itself need to
# access the hdf5_file and does not need a series_name (provided
# by the base_series.)
class AccumulatorSeriesWrapper():
    """
    Accumulates the elements it is given and, every reduce_every
    elements, appends their reduction (by default the mean) to the
    wrapped series.
    """
    def __init__(self, base_series, reduce_every, reduce_function=numpy.mean):
        """
        Parameters
        ----------
        base_series : object with an append(index, element) method
            Series receiving the reduced values.
        reduce_every : int
            Number of elements to accumulate before reducing. Must be
            at least 1.
        reduce_function : callable
            Called with the list of accumulated elements; its result is
            what gets appended to base_series.

        Raises
        ------
        ValueError
            If reduce_every is smaller than 1.
        """
        if reduce_every < 1:
            raise ValueError("reduce_every must be at least 1 (got " \
                            + str(reduce_every) + ")")

        self.base_series = base_series
        self.reduce_function = reduce_function
        self.reduce_every = reduce_every

        self._buffer = []

    def append(self, index, element):
        """
        Parameters
        ----------
        index : tuple of int
            The index used is the one of the last element reduced. E.g. if
            you accumulate over the first 1000 minibatches, the index
            passed to the base_series.append() function will be 1000.
        element : value accepted by reduce_function
            Element to accumulate.
        """
        self._buffer.append(element)

        if len(self._buffer) == self.reduce_every:
            reduced = self.reduce_function(self._buffer)
            self.base_series.append(index, reduced)
            self._buffer = []

        # This should never happen, except if lists
        # were appended, which should be a red flag.
        assert len(self._buffer) < self.reduce_every

# Outside of class to fix an issue with exec in Python 2.6.
# My sorries to the God of pretty code.
def BasicStatisticsSeries_construct_table_toexec(index_names):
    """
    Constructs the description class of a basic statistics table: the
    index columns followed by the "mean", "min", "max" and "std" float
    columns.

    Parameters
    ----------
    index_names : tuple of str
        Names of the index columns.

    Returns
    -------
    A class object, to pass to createTable()
    """
    toexec = get_beginning_description_n_ints(index_names)

    # Statistics columns come right after the index columns.
    bpos = len(index_names)
    for offset, stat_name in enumerate(("mean", "min", "max", "std")):
        toexec += "\t" + stat_name + " = Float32Col(pos=" + str(bpos + offset) + ")\n"

    # This creates "LocalDescription", which we may then use. exec()
    # cannot reliably create local variables of the enclosing function,
    # so the class is created in an explicit namespace and fetched from
    # there. This call form is valid in Python 2 and 3.
    namespace = {}
    exec(toexec, globals(), namespace)

    return namespace["LocalDescription"]

class BasicStatisticsSeries(Series):
    """
    Series storing, for each array appended, its mean, min, max and
    standard deviation.

    Parameters
    ----------
    table_name : str
        Name of the table to create.
    hdf5_file : open HDF5 file object
        File in which the table is created.
    index_names : tuple of str
        Names of the index columns.
    title : str
        Title of the table. Only passed to createTable() when provided.
    hdf5_group : str
        Group under which the table is created.
    """
    def __init__(self, table_name, hdf5_file, index_names=('epoch',), title=None, hdf5_group='/'):
        Series.__init__(self, table_name, hdf5_file, index_names, title)

        self.hdf5_group = hdf5_group

        self.construct_table()

    def construct_table(self):
        """Creates the table holding the index and statistics columns."""
        table_description = BasicStatisticsSeries_construct_table_toexec(self.index_names)

        # The call is left exactly as it was when no title is given, so
        # that the PyTables default title applies.
        if self.title is None:
            self._table = self.hdf5_file.createTable(self.hdf5_group, self.table_name, table_description)
        else:
            self._table = self.hdf5_file.createTable(self.hdf5_group, self.table_name, table_description, title=self.title)

    def append(self, index, array):
        """
        Adds one row holding the statistics of array and flushes the file.

        Parameters
        ----------
        index : tuple of int
            One value per name in index_names.
        array : array-like
            Values over which numpy computes the mean, min, max and std.

        Raises
        ------
        ValueError
            If index does not have as many elements as index_names.
        """
        if len(index) != len(self.index_names):
            raise ValueError("index provided does not have the right length (expected " \
                            + str(len(self.index_names)) + " got " + str(len(index)) + ")")

        newrow = self._table.row

        for col_name, value in zip(self.index_names, index):
            newrow[col_name] = value

        newrow["mean"] = numpy.mean(array)
        newrow["min"] = numpy.min(array)
        newrow["max"] = numpy.max(array)
        newrow["std"] = numpy.std(array)

        newrow.append()

        self.hdf5_file.flush()

class SeriesArrayWrapper():
    """
    Simply redistributes any number of elements to sub-series to respective append()s.
    """

    def __init__(self, base_series_list):
        """
        Parameters
        ----------
        base_series_list : list of series
            Objects with an append(index, element) method, one per
            element later given to append().
        """
        self.base_series_list = base_series_list

    def append(self, index, elements):
        """
        Appends elements[i] to base_series_list[i], all with the same index.

        Raises
        ------
        ValueError
            If there is not exactly one element per sub-series.
        """
        if len(elements) != len(self.base_series_list):
            raise ValueError("not enough or too many elements provided (expected " \
                            + str(len(self.base_series_list)) + " got " + str(len(elements)) + ")")

        for series, el in zip(self.base_series_list, elements):
            series.append(index, el)

class ParamsStatisticsWrapper(SeriesArrayWrapper):
    """
    Creates a new group holding one BasicStatisticsSeries per array
    name, and redistributes the arrays appended to them.
    """
    def __init__(self, arrays_names, new_group_name, hdf5_file, base_group='/', index_names=('epoch',), title=""):
        """
        Parameters
        ----------
        arrays_names : sequence of str
            One table is created per name.
        new_group_name : str
            Name of the group created under base_group.
        hdf5_file : open HDF5 file object
            File in which the group and tables are created.
        base_group : str
            Group under which the new group is created.
        index_names : tuple of str
            Names of the index columns of every table.
        title : str
            Title of the new group.
        """
        new_group = hdf5_file.createGroup(base_group, new_group_name, title=title)

        # index_names is passed on as given: the index tuples appended
        # later must have exactly that many elements.
        base_series_list = [
            BasicStatisticsSeries(
                table_name=name,
                hdf5_file=hdf5_file,
                index_names=index_names,
                hdf5_group=new_group._v_pathname)
            for name in arrays_names]

        SeriesArrayWrapper.__init__(self, base_series_list)