changeset 910:8837535006f1
Added the SeriesTables module (originally placed in the IFT6266 repository) and its documentation. Added a logo for the doc and changed conf.py to reflect the correct project title.
author | fsavard
date | Thu, 18 Mar 2010 12:23:14 -0400
parents | 8e3f1d852ab1
children | fdb63e4e042d
files | doc/conf.py doc/index.txt pylearn/io/seriestables/__init__.py pylearn/io/seriestables/series.py pylearn/io/seriestables/test_series.py
diffstat | 5 files changed, 924 insertions(+), 5 deletions(-)
--- a/doc/conf.py	Thu Mar 18 11:34:29 2010 -0400
+++ b/doc/conf.py	Thu Mar 18 12:23:14 2010 -0400
@@ -44,7 +44,7 @@
 master_doc = 'index'
 
 # General substitutions.
-project = 'Theano'
+project = 'Pylearn'
 copyright = '2008--2009, LISA lab'
 
 # The default replacements for |version| and |release|, also used in various
@@ -66,7 +66,7 @@
 # List of directories, relative to source directories, that shouldn't be searched
 # for source files.
-exclude_dirs = ['images', 'scripts', 'sandbox']
+exclude_dirs = ['images', 'scripts', 'api']
 
 # The reST default role (used for this markup: `text`) to use for all documents.
 #default_role = None
 
@@ -105,7 +105,7 @@
 # The name of an image file (within the static path) to place at the top of
 # the sidebar.
 #html_logo = 'images/theano_logo-200x67.png'
-html_logo = 'images/theano_logo_allblue_200x46.png'
+html_logo = 'images/logo_pylearn_200x57.png'
 
 # The name of an image file (within the static path) to use as favicon of the
 # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
@@ -153,7 +153,7 @@
 #html_file_suffix = ''
 
 # Output file base name for HTML help builder.
-htmlhelp_basename = 'theanodoc'
+htmlhelp_basename = 'pylearndoc'
 
 
 # Options for LaTeX output
@@ -168,7 +168,7 @@
 # Grouping the document tree into LaTeX files. List of tuples
 # (source start file, target name, title, author, document class [howto/manual]).
 latex_documents = [
-    ('index', 'theano.tex', 'theano Documentation',
+    ('index', 'pylearn.tex', 'pylearn Documentation',
      'LISA lab, University of Montreal', 'manual'),
 ]
--- a/doc/index.txt	Thu Mar 18 11:34:29 2010 -0400
+++ b/doc/index.txt	Thu Mar 18 12:23:14 2010 -0400
@@ -24,6 +24,7 @@
 
 For the moment, the following documentation is available.
 
+* :doc:`io.SeriesTables module <seriestables>` -- Saves error series and other statistics during training
 * `API <api/>`_ -- The automatically-generated API documentation
 
 You can download the latest `PDF documentation <http://deeplearning.net/software/pylearn/pylearn.pdf>`_, rather than reading it online.
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/io/seriestables/__init__.py	Thu Mar 18 12:23:14 2010 -0400
@@ -0,0 +1,2 @@
+from series import ErrorSeries, BasicStatisticsSeries, AccumulatorSeriesWrapper, SeriesArrayWrapper, SharedParamsStatisticsWrapper, DummySeries
+
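Before diving into series.py itself, a minimal usage sketch of the exported API may help; it mirrors what test_series.py (further down) does. It assumes PyTables 2.x, whose `openFile`/`createTable` spelling is used throughout this changeset; the file name `experiment.h5` is just an example.

```python
import tables
from pylearn.io.seriestables import ErrorSeries

# Open an HDF5 file that will hold all series (hypothetical path).
h5f = tables.openFile("experiment.h5", "w")

# One table, one float per row, indexed by (epoch, minibatch).
valid_err = ErrorSeries(error_name="validation_error",
                        table_name="validation_error",
                        hdf5_file=h5f,
                        index_names=('epoch', 'minibatch'))

valid_err.append((1, 1), 32.0)  # index follows index_names
h5f.close()
```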
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/io/seriestables/series.py	Thu Mar 18 12:23:14 2010 -0400
@@ -0,0 +1,605 @@
+import tables
+
+import numpy
+import time
+
+##############################################################################
+# Utility functions to create IsDescription objects (pytables data types)
+
+'''
+The way these "IsDescription constructors" work is simple: write the
+code as if it were in a file, then exec()ute it, leaving us with
+a locally-scoped LocalDescription class which may be used to call
+createTable.
+
+It's a small hack, but it's necessary as the names of the columns
+are derived from the class attribute names, which we can't set
+programmatically otherwise.
+'''
+
+def _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock, pos=0):
+    toexec = ""
+
+    if store_timestamp:
+        toexec += "\ttimestamp = tables.Time32Col(pos=" + str(pos) + ")\n"
+        pos += 1
+
+    if store_cpuclock:
+        toexec += "\tcpuclock = tables.Float64Col(pos=" + str(pos) + ")\n"
+        pos += 1
+
+    return toexec, pos
+
+def _get_description_n_ints(int_names, int_width=64, pos=0):
+    """
+    Begins construction of a class inheriting from IsDescription
+    to construct an HDF5 table with index columns named with int_names.
+
+    See Series.__init__ to see how those are used.
+    """
+    int_constructor = "tables.Int64Col"
+    if int_width == 32:
+        int_constructor = "tables.Int32Col"
+    elif int_width != 64:
+        raise ValueError("int_width must be left unspecified, or should equal 32 or 64")
+
+    toexec = ""
+
+    for n in int_names:
+        toexec += "\t" + n + " = " + int_constructor + "(pos=" + str(pos) + ")\n"
+        pos += 1
+
+    return toexec, pos
+
+def _get_description_with_n_ints_n_floats(int_names, float_names,
+                        int_width=64, float_width=32,
+                        store_timestamp=True, store_cpuclock=True):
+    """
+    Constructs a class to be used when constructing a table with PyTables.
+
+    This is useful to construct a series with an index with multiple levels.
+    E.g. if you want to index your "validation error" with "epoch" first, then
+    "minibatch_index" second, you'd use two "int_names".
+
+    Parameters
+    ----------
+    int_names : tuple of str
+        Names of the int (e.g. index) columns
+    float_names : tuple of str
+        Names of the float (e.g. error) columns
+    int_width : {32, 64}
+        Bit width of the int columns.
+    float_width : {32, 64}
+        Bit width of the float columns.
+    store_timestamp : bool
+        See __init__ of Series
+    store_cpuclock : bool
+        See __init__ of Series
+
+    Returns
+    -------
+    A class object, to pass to createTable()
+    """
+
+    toexec = "class LocalDescription(tables.IsDescription):\n"
+
+    toexec_, pos = _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock)
+    toexec += toexec_
+
+    toexec_, pos = _get_description_n_ints(int_names, int_width=int_width, pos=pos)
+    toexec += toexec_
+
+    float_constructor = "tables.Float32Col"
+    if float_width == 64:
+        float_constructor = "tables.Float64Col"
+    elif float_width != 32:
+        raise ValueError("float_width must be left unspecified, or should equal 32 or 64")
+
+    for n in float_names:
+        toexec += "\t" + n + " = " + float_constructor + "(pos=" + str(pos) + ")\n"
+        pos += 1
+
+    exec(toexec)
+
+    return LocalDescription
+
+##############################################################################
+# Series classes
+
+# Shortcut to allow passing a single int as index, instead of a tuple
+def _index_to_tuple(index):
+    if type(index) == tuple:
+        return index
+
+    if type(index) == list:
+        return tuple(index)
+
+    try:
+        # reject floats that are not (almost) whole numbers
+        if 0.001 < index % 1 < 0.999:
+            raise TypeError()
+        idx = long(index)
+        return (idx,)
+    except (TypeError, ValueError):
+        raise TypeError("index must be a tuple of integers, or at least a single integer")
+
+class Series(object):
+    """
+    Base Series class, with minimal arguments and type checks.
+
+    It cannot be used by itself (its append() method raises
+    NotImplementedError).
+    """
+
+    def __init__(self, table_name, hdf5_file, index_names=('epoch',),
+                    title="", hdf5_group='/',
+                    store_timestamp=True, store_cpuclock=True):
+        """Basic arguments each Series must get.
+
+        Parameters
+        ----------
+        table_name : str
+            Name of the table to create under group "hdf5_group" (other
+            parameter). No spaces; i.e. follow variable naming restrictions.
+        hdf5_file : open HDF5 file
+            File opened with openFile() in PyTables (i.e. return value of
+            openFile).
+        index_names : tuple of str
+            Columns to use as index for elements in the series; another
+            example would be ('epoch', 'minibatch'). This would then allow
+            you to call append(index, element) with index made of two ints,
+            one for epoch index, one for minibatch index in epoch.
+        title : str
+            Title to attach to this table as metadata. Can contain spaces
+            and be longer than the table_name.
+        hdf5_group : str
+            Path of the group (kind of a folder) in the HDF5 file under which
+            to create the table.
+        store_timestamp : bool
+            Whether to create a column for timestamps and store them with
+            each record.
+        store_cpuclock : bool
+            Whether to create a column for cpu clock and store it with
+            each record.
+        """
+
+        #########################################
+        # checks
+
+        if type(table_name) != str:
+            raise TypeError("table_name must be a string")
+        if table_name == "":
+            raise ValueError("table_name must not be empty")
+
+        if not isinstance(hdf5_file, tables.file.File):
+            raise TypeError("hdf5_file must be an open HDF5 file (use tables.openFile)")
+        #if not ('w' in hdf5_file.mode or 'a' in hdf5_file.mode):
+        #    raise ValueError("hdf5_file must be opened in write or append mode")
+
+        if type(index_names) != tuple:
+            raise TypeError("index_names must be a tuple of strings. " + \
+                    "If you have only one element in the tuple, don't forget " + \
+                    "to add a comma, e.g. ('epoch',).")
+        for name in index_names:
+            if type(name) != str:
+                raise TypeError("index_names must only contain strings, but also " + \
+                        "contains a " + str(type(name)) + ".")
+
+        if type(title) != str:
+            raise TypeError("title must be a string, even if empty")
+
+        if type(hdf5_group) != str:
+            raise TypeError("hdf5_group must be a string")
+
+        if type(store_timestamp) != bool:
+            raise TypeError("store_timestamp must be a bool")
+
+        if type(store_cpuclock) != bool:
+            raise TypeError("store_cpuclock must be a bool")
+
+        #########################################
+
+        self.table_name = table_name
+        self.hdf5_file = hdf5_file
+        self.index_names = index_names
+        self.title = title
+        self.hdf5_group = hdf5_group
+
+        self.store_timestamp = store_timestamp
+        self.store_cpuclock = store_cpuclock
+
+    def append(self, index, element):
+        raise NotImplementedError
+
+    def _timestamp_cpuclock(self, newrow):
+        if self.store_timestamp:
+            newrow["timestamp"] = time.time()
+
+        if self.store_cpuclock:
+            newrow["cpuclock"] = time.clock()
+
+class DummySeries(object):
+    """
+    To put in a series dictionary instead of a real series, to do nothing
+    when we don't want a given series to be saved.
+
+    E.g. if we'd normally have a "training_error" series in a dictionary
+    of series, the training loop would have something like this somewhere:
+
+        series["training_error"].append((15,), 20.0)
+
+    but if we don't want to save the training errors this time, we simply
+    do
+
+        series["training_error"] = DummySeries()
+    """
+    def append(self, index, element):
+        pass
+
+class ErrorSeries(Series):
+    """
+    Most basic Series: saves a single float (called an Error as this is
+    the most common use case I foresee) along with an index (epoch, for
+    example) and timestamp/cpuclock for each of these floats.
+    """
+
+    def __init__(self, error_name, table_name,
+                    hdf5_file, index_names=('epoch',),
+                    title="", hdf5_group='/',
+                    store_timestamp=True, store_cpuclock=True):
+        """
+        For most parameters, see Series.__init__
+
+        Parameters
+        ----------
+        error_name : str
+            In the HDF5 table, column name for the error float itself.
+        """
+
+        # most type/value checks are performed in Series.__init__
+        Series.__init__(self, table_name, hdf5_file, index_names, title,
+                            hdf5_group=hdf5_group,
+                            store_timestamp=store_timestamp,
+                            store_cpuclock=store_cpuclock)
+
+        if type(error_name) != str:
+            raise TypeError("error_name must be a string")
+        if error_name == "":
+            raise ValueError("error_name must not be empty")
+
+        self.error_name = error_name
+
+        self._create_table()
+
+    def _create_table(self):
+        table_description = _get_description_with_n_ints_n_floats( \
+                        self.index_names, (self.error_name,),
+                        store_timestamp=self.store_timestamp,
+                        store_cpuclock=self.store_cpuclock)
+
+        self._table = self.hdf5_file.createTable(self.hdf5_group,
+                            self.table_name,
+                            table_description,
+                            title=self.title)
+
+    def append(self, index, error):
+        """
+        Parameters
+        ----------
+        index : tuple of int
+            Following index_names passed to __init__, e.g. (12, 15) if
+            index_names were ('epoch', 'minibatch_size').
+            A single int (not tuple) is acceptable if index_names has a single
+            element.
+            A list will be cast to a tuple, as a convenience.
+        error : float
+            Next error in the series.
+        """
+        index = _index_to_tuple(index)
+
+        if len(index) != len(self.index_names):
+            raise ValueError("index provided does not have the right length (expected " \
+                            + str(len(self.index_names)) + " got " + str(len(index)) + ")")
+
+        # other checks are implicit when calling newrow[...] =,
+        # which should throw an error if not of the right type
+
+        newrow = self._table.row
+
+        # Columns for index in table are based on index_names
+        for col_name, value in zip(self.index_names, index):
+            newrow[col_name] = value
+        newrow[self.error_name] = error
+
+        # adds timestamp and cpuclock to newrow if necessary
+        self._timestamp_cpuclock(newrow)
+
+        newrow.append()
+
+        self.hdf5_file.flush()
+
+# Does not inherit from Series because it does not itself need to
+# access the hdf5_file and does not need a series_name (provided
+# by the base_series.)
+class AccumulatorSeriesWrapper(object):
+    '''
+    Wraps a Series, accumulating the elements passed to its append()
+    method and "reducing" them (e.g. calling numpy.mean(list)) once in
+    a while, every "reduce_every" calls in fact.
+    '''
+
+    def __init__(self, base_series, reduce_every, reduce_function=numpy.mean):
+        """
+        Parameters
+        ----------
+        base_series : Series
+            This object must have an append(index, value) function.
+        reduce_every : int
+            Apply the reduction function (e.g. mean()) every time we get this
+            number of elements. E.g. if this is 100, then every 100 numbers
+            passed to append(), we'll take the mean and call append(this_mean)
+            on the base_series.
+        reduce_function : function
+            Must take as input an array of "elements", as passed to (this
+            accumulator's) append(). Basic case would be to take an array of
+            floats and sum them into one float, for example.
+        """
+        self.base_series = base_series
+        self.reduce_function = reduce_function
+        self.reduce_every = reduce_every
+
+        self._buffer = []
+
+    def append(self, index, element):
+        """
+        Parameters
+        ----------
+        index : tuple of int
+            The index used is the one of the last element reduced. E.g. if
+            you accumulate over the first 1000 minibatches, the index
+            passed to the base_series.append() function will be 1000.
+            A single int (not tuple) is acceptable if index_names has a single
+            element.
+            A list will be cast to a tuple, as a convenience.
+        element : float
+            Element that will be accumulated.
+        """
+        self._buffer.append(element)
+
+        if len(self._buffer) == self.reduce_every:
+            reduced = self.reduce_function(self._buffer)
+            self.base_series.append(index, reduced)
+            self._buffer = []
+
+        # The >= case should never happen, except if lists
+        # were appended by accessing _buffer externally (when it's
+        # intended to be private), which should be a red flag.
+        assert len(self._buffer) < self.reduce_every
+
+# Outside of class to fix an issue with exec in Python 2.6.
+# My sorries to the god of pretty code.
+def _BasicStatisticsSeries_construct_table_toexec(index_names, store_timestamp, store_cpuclock):
+    toexec = "class LocalDescription(tables.IsDescription):\n"
+
+    toexec_, pos = _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock)
+    toexec += toexec_
+
+    toexec_, pos = _get_description_n_ints(index_names, pos=pos)
+    toexec += toexec_
+
+    toexec += "\tmean = tables.Float32Col(pos=" + str(pos) + ")\n"
+    toexec += "\tmin = tables.Float32Col(pos=" + str(pos+1) + ")\n"
+    toexec += "\tmax = tables.Float32Col(pos=" + str(pos+2) + ")\n"
+    toexec += "\tstd = tables.Float32Col(pos=" + str(pos+3) + ")\n"
+
+    # This creates "LocalDescription", which we may then use
+    exec(toexec)
+
+    return LocalDescription
+
+# Default functions for BasicStatisticsSeries. These can be replaced.
+_basic_stats_functions = {'mean': numpy.mean,
+                          'min': numpy.min,
+                          'max': numpy.max,
+                          'std': numpy.std}
+
+class BasicStatisticsSeries(Series):
+
+    def __init__(self, table_name, hdf5_file,
+                    stats_functions=_basic_stats_functions,
+                    index_names=('epoch',), title="", hdf5_group='/',
+                    store_timestamp=True, store_cpuclock=True):
+        """
+        For most parameters, see Series.__init__
+
+        Parameters
+        ----------
+        table_name : str
+            Not optional here.
+        stats_functions : dict, optional
+            Dictionary with a function for each key "mean", "min", "max",
+            "std". The function must take whatever is passed to append(...)
+            and return a single number (float).
+        """
+
+        # Most type/value checks performed in Series.__init__
+        Series.__init__(self, table_name, hdf5_file, index_names, title,
+                            hdf5_group=hdf5_group,
+                            store_timestamp=store_timestamp,
+                            store_cpuclock=store_cpuclock)
+
+        if type(stats_functions) != dict:
+            # just a basic check. We'll suppose caller knows what he's doing.
+            raise TypeError("stats_functions must be a dict")
+
+        self.stats_functions = stats_functions
+
+        self._create_table()
+
+    def _create_table(self):
+        table_description = \
+                _BasicStatisticsSeries_construct_table_toexec( \
+                    self.index_names,
+                    self.store_timestamp, self.store_cpuclock)
+
+        self._table = self.hdf5_file.createTable(self.hdf5_group,
+                            self.table_name, table_description,
+                            title=self.title)
+
+    def append(self, index, array):
+        """
+        Parameters
+        ----------
+        index : tuple of int
+            Following index_names passed to __init__, e.g. (12, 15)
+            if index_names were ('epoch', 'minibatch_size').
+            A single int (not tuple) is acceptable if index_names has a single
+            element.
+            A list will be cast to a tuple, as a convenience.
+        array
+            Is of whatever type the stats_functions passed to
+            __init__ can take. Default is anything numpy.mean(),
+            min(), max(), std() can take.
+        """
+        index = _index_to_tuple(index)
+
+        if len(index) != len(self.index_names):
+            raise ValueError("index provided does not have the right length (expected " \
+                            + str(len(self.index_names)) + " got " + str(len(index)) + ")")
+
+        newrow = self._table.row
+
+        for col_name, value in zip(self.index_names, index):
+            newrow[col_name] = value
+
+        newrow["mean"] = self.stats_functions['mean'](array)
+        newrow["min"] = self.stats_functions['min'](array)
+        newrow["max"] = self.stats_functions['max'](array)
+        newrow["std"] = self.stats_functions['std'](array)
+
+        self._timestamp_cpuclock(newrow)
+
+        newrow.append()
+
+        self.hdf5_file.flush()
+
+class SeriesArrayWrapper(object):
+    """
+    Simply redistributes any number of elements to sub-series, one per
+    sub-series append().
+
+    To use if you have many elements to append in similar series, e.g. if you
+    have an array containing [train_error, valid_error, test_error], and 3
+    corresponding series, this allows you to simply pass this array of 3
+    values to append() instead of passing each element to each individual
+    series in turn.
+    """
+
+    def __init__(self, base_series_list):
+        """
+        Parameters
+        ----------
+        base_series_list : array or tuple of Series
+            You must have previously created and configured each of those
+            series, then put them in an array. This array must follow the
+            same order as the array passed as the ``elements`` parameter of
+            append().
+        """
+        self.base_series_list = base_series_list
+
+    def append(self, index, elements):
+        """
+        Parameters
+        ----------
+        index : tuple of int
+            See for example ErrorSeries.append()
+        elements : array or tuple
+            Array or tuple of elements that will be passed down to
+            the base_series passed to __init__, in the same order.
+        """
+        if len(elements) != len(self.base_series_list):
+            raise ValueError("wrong number of elements provided (expected " \
+                            + str(len(self.base_series_list)) + " got " + str(len(elements)) + ")")
+
+        for series, el in zip(self.base_series_list, elements):
+            series.append(index, el)
+
+class SharedParamsStatisticsWrapper(SeriesArrayWrapper):
+    '''
+    Saves mean, min/max, std of shared parameters placed in an array.
+
+    Here "shared" means "theano.shared", which means elements of the
+    array will have a .value to use for numpy.mean(), etc.
+
+    This inherits from SeriesArrayWrapper, which provides the append()
+    method.
+    '''
+
+    def __init__(self, arrays_names, new_group_name, hdf5_file,
+                    base_group='/', index_names=('epoch',), title="",
+                    store_timestamp=True, store_cpuclock=True):
+        """
+        For other parameters, see Series.__init__
+
+        Parameters
+        ----------
+        arrays_names : array or tuple of str
+            Name of each array, in order of the array passed to append(). E.g.
+            ('layer1_b', 'layer1_W', 'layer2_b', 'layer2_W')
+        new_group_name : str
+            Name of a new HDF5 group which will be created under base_group to
+            store the new series.
+        base_group : str
+            Path of the group under which to create the new group which will
+            store the series.
+        title : str
+            Here the title is attached to the new group, not a table.
+        store_timestamp : bool
+            Here timestamp and cpuclock are stored in *each* table
+        store_cpuclock : bool
+            Here timestamp and cpuclock are stored in *each* table
+        """
+
+        # most other checks done when calling BasicStatisticsSeries
+        if type(new_group_name) != str:
+            raise TypeError("new_group_name must be a string")
+        if new_group_name == "":
+            raise ValueError("new_group_name must not be empty")
+
+        base_series_list = []
+
+        new_group = hdf5_file.createGroup(base_group, new_group_name, title=title)
+
+        stats_functions = {'mean': lambda x: numpy.mean(x.value),
+                           'min': lambda x: numpy.min(x.value),
+                           'max': lambda x: numpy.max(x.value),
+                           'std': lambda x: numpy.std(x.value)}
+
+        for name in arrays_names:
+            base_series_list.append(
+                        BasicStatisticsSeries(
+                                table_name=name,
+                                hdf5_file=hdf5_file,
+                                index_names=index_names,
+                                stats_functions=stats_functions,
+                                hdf5_group=new_group._v_pathname,
+                                store_timestamp=store_timestamp,
+                                store_cpuclock=store_cpuclock))
+
+        SeriesArrayWrapper.__init__(self, base_series_list)
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/io/seriestables/test_series.py	Thu Mar 18 12:23:14 2010 -0400
@@ -0,0 +1,311 @@
+import tempfile
+
+import numpy
+import numpy.random
+
+from jobman import DD
+
+import tables
+
+from series import *
+import series
+
+#################################################
+# Utils
+
+def compare_floats(f1, f2):
+    # absolute difference; a one-sided "f1 - f2 < 1e-3" check would
+    # accept any f1 much smaller than f2
+    return abs(f1 - f2) < 1e-3
+
+def compare_lists(it1, it2, floats=False):
+    if len(it1) != len(it2):
+        return False
+
+    for el1, el2 in zip(it1, it2):
+        if floats:
+            if not compare_floats(el1, el2):
+                return False
+        elif el1 != el2:
+            return False
+
+    return True
+
+#################################################
+# Basic Series class tests
+
+def test_Series_types():
+    pass
+
+#################################################
+# ErrorSeries tests
+
+def test_ErrorSeries_common_case(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    validation_error = series.ErrorSeries(error_name="validation_error",
+                                table_name="validation_error",
+                                hdf5_file=h5f,
+                                index_names=('epoch','minibatch'),
+                                title="Validation error indexed by epoch and minibatch")
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) indices
+    validation_error.append((1,1), 32.0)
+    validation_error.append((1,2), 30.0)
+    validation_error.append((2,1), 28.0)
+    validation_error.append((2,2), 26.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+
+    table = h5f.getNode('/', 'validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [1,1,2,2])
+    assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
+    assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
+
+def test_ErrorSeries_no_index(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    validation_error = series.ErrorSeries(error_name="validation_error",
+                                table_name="validation_error",
+                                hdf5_file=h5f,
+                                # empty tuple
+                                index_names=tuple(),
+                                title="Validation error with no index")
+
+    # no index columns here, so append with an empty tuple
+    validation_error.append(tuple(), 32.0)
+    validation_error.append(tuple(), 30.0)
+    validation_error.append(tuple(), 28.0)
+    validation_error.append(tuple(), 26.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+
+    table = h5f.getNode('/', 'validation_error')
+
+    assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
+    assert not ("epoch" in dir(table.cols))
+
+def test_ErrorSeries_notimestamp(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    validation_error = series.ErrorSeries(error_name="validation_error",
+                                table_name="validation_error",
+                                hdf5_file=h5f,
+                                index_names=('epoch','minibatch'),
+                                title="Validation error indexed by epoch and minibatch",
+                                store_timestamp=False)
+
+    # (1,1) is an (epoch, minibatch) index
+    validation_error.append((1,1), 32.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+
+    table = h5f.getNode('/', 'validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [1])
+    assert not ("timestamp" in dir(table.cols))
+    assert "cpuclock" in dir(table.cols)
+
+def test_ErrorSeries_nocpuclock(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    validation_error = series.ErrorSeries(error_name="validation_error",
+                                table_name="validation_error",
+                                hdf5_file=h5f,
+                                index_names=('epoch','minibatch'),
+                                title="Validation error indexed by epoch and minibatch",
+                                store_cpuclock=False)
+
+    # (1,1) is an (epoch, minibatch) index
+    validation_error.append((1,1), 32.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+
+    table = h5f.getNode('/', 'validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [1])
+    assert not ("cpuclock" in dir(table.cols))
+    assert "timestamp" in dir(table.cols)
+
+def test_AccumulatorSeriesWrapper_common_case(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    validation_error = ErrorSeries(error_name="accumulated_validation_error",
+                                table_name="accumulated_validation_error",
+                                hdf5_file=h5f,
+                                index_names=('epoch','minibatch'),
+                                title="Validation error, summed every 3 minibatches, indexed by epoch and minibatch")
+
+    accumulator = AccumulatorSeriesWrapper(base_series=validation_error,
+                                reduce_every=3, reduce_function=numpy.sum)
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) indices
+    accumulator.append((1,1), 32.0)
+    accumulator.append((1,2), 30.0)
+    accumulator.append((2,1), 28.0)
+    accumulator.append((2,2), 26.0)
+    accumulator.append((3,1), 24.0)
+    accumulator.append((3,2), 22.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+
+    table = h5f.getNode('/', 'accumulated_validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [2,3])
+    assert compare_lists(table.cols.minibatch[:], [1,2])
+    assert compare_lists(table.cols.accumulated_validation_error[:], [90.0,72.0], floats=True)
+
+def test_BasicStatisticsSeries_common_case(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    stats_series = BasicStatisticsSeries(table_name="b_vector_statistics",
+                                hdf5_file=h5f, index_names=('epoch','minibatch'),
+                                title="Basic statistics for b vector indexed by epoch and minibatch")
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) indices
+    stats_series.append((1,1), [0.15, 0.20, 0.30])
+    stats_series.append((1,2), [-0.18, 0.30, 0.58])
+    stats_series.append((2,1), [0.18, -0.38, -0.68])
+    stats_series.append((2,2), [0.15, 0.02, 1.9])
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+
+    table = h5f.getNode('/', 'b_vector_statistics')
+
+    assert compare_lists(table.cols.epoch[:], [1,1,2,2])
+    assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
+    assert compare_lists(table.cols.mean[:], [0.21666667, 0.23333333, -0.29333332, 0.69], floats=True)
+    assert compare_lists(table.cols.min[:], [0.15000001, -0.18000001, -0.68000001, 0.02], floats=True)
+    assert compare_lists(table.cols.max[:], [0.30, 0.58, 0.18, 1.9], floats=True)
+    assert compare_lists(table.cols.std[:], [0.06236095, 0.31382939, 0.35640177, 0.85724366], floats=True)
+
+def test_SharedParamsStatisticsWrapper_commoncase(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
+                                arrays_names=('b1','b2','b3'), hdf5_file=h5f,
+                                index_names=('epoch','minibatch'))
+
+    b1 = DD({'value': numpy.random.rand(5)})
+    b2 = DD({'value': numpy.random.rand(5)})
+    b3 = DD({'value': numpy.random.rand(5)})
+    stats.append((1,1), [b1,b2,b3])
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+
+    b1_table = h5f.getNode('/params', 'b1')
+    b3_table = h5f.getNode('/params', 'b3')
+
+    assert abs(b1_table.cols.mean[0] - numpy.mean(b1.value)) < 1e-3
+    assert abs(b3_table.cols.mean[0] - numpy.mean(b3.value)) < 1e-3
+    assert abs(b1_table.cols.min[0] - numpy.min(b1.value)) < 1e-3
+    assert abs(b3_table.cols.min[0] - numpy.min(b3.value)) < 1e-3
+
+def test_SharedParamsStatisticsWrapper_notimestamp(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
+                                arrays_names=('b1','b2','b3'), hdf5_file=h5f,
+                                index_names=('epoch','minibatch'),
+                                store_timestamp=False)
+
+    b1 = DD({'value': numpy.random.rand(5)})
+    b2 = DD({'value': numpy.random.rand(5)})
+    b3 = DD({'value': numpy.random.rand(5)})
+    stats.append((1,1), [b1,b2,b3])
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+
+    b1_table = h5f.getNode('/params', 'b1')
+    b3_table = h5f.getNode('/params', 'b3')
+
+    assert abs(b1_table.cols.mean[0] - numpy.mean(b1.value)) < 1e-3
+    assert abs(b3_table.cols.mean[0] - numpy.mean(b3.value)) < 1e-3
+    assert abs(b1_table.cols.min[0] - numpy.min(b1.value)) < 1e-3
+    assert abs(b3_table.cols.min[0] - numpy.min(b3.value)) < 1e-3
+
+    assert not ('timestamp' in dir(b1_table.cols))
+
+def test_get_desc():
+    h5f_path = tempfile.NamedTemporaryFile().name
+    h5f = tables.openFile(h5f_path, "w")
+
+    desc = series._get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4"))
+
+    mytable = h5f.createTable('/', 'mytable', desc)
+
+    # just make sure the columns are there... otherwise this will throw an exception
+    mytable.cols.col1
+    mytable.cols.col2
+    mytable.cols.col3
+    mytable.cols.col4
+
+    try:
+        # this should fail... LocalDescription must be local to
+        # _get_description_with_n_ints_n_floats
+        test = LocalDescription
+        assert False
+    except NameError:
+        pass
+
+def test_index_to_tuple_floaterror():
+    try:
+        series._index_to_tuple(5.1)
+        assert False
+    except TypeError:
+        pass
+
+def test_index_to_tuple_arrayok():
+    tpl = series._index_to_tuple([1,2,3])
+    assert type(tpl) == tuple and tpl[1] == 2 and tpl[2] == 3
+
+def test_index_to_tuple_intbecomestuple():
+    tpl = series._index_to_tuple(32)
+    assert type(tpl) == tuple and tpl == (32,)
+
+def test_index_to_tuple_longbecomestuple():
+    tpl = series._index_to_tuple(928374928374928L)
+    assert type(tpl) == tuple and tpl == (928374928374928L,)
+
+if __name__ == '__main__':
+    test_get_desc()
+    test_ErrorSeries_common_case()
+    test_BasicStatisticsSeries_common_case()
+    test_AccumulatorSeriesWrapper_common_case()
+    test_SharedParamsStatisticsWrapper_commoncase()
+
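As the tests above show, reading the stored series back requires nothing from this module; plain PyTables is enough. A minimal sketch, assuming a file written by the earlier training-loop example (the path and column names follow that sketch):

```python
import tables

h5f = tables.openFile("training_stats.h5", "r")
table = h5f.getNode('/', 'train_error')

# Each column (the index names, the error itself, and timestamp/cpuclock
# when those were kept) is addressable by name on table.cols.
print table.cols.minibatch[:]
print table.cols.train_error[:]

h5f.close()
```

The test module's `__main__` block runs the main cases directly (`python test_series.py`), provided PyTables and jobman are importable.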