Mercurial > pylearn
changeset 929:34d1cd516f76
Added other targets (printing to stdout, notably) to seriestables, and corresponding doc
author | fsavard |
---|---|
date | Thu, 15 Apr 2010 09:11:14 -0400 |
parents | 87d416e1f4fd |
children | b2a60af9cc28 |
files | doc/seriestables.txt pylearn/io/seriestables/__init__.py pylearn/io/seriestables/series.py pylearn/io/seriestables/test_series.py |
diffstat | 4 files changed, 209 insertions(+), 13 deletions(-) [+] |
line wrap: on
line diff
--- a/doc/seriestables.txt Fri Apr 09 11:14:24 2010 -0400 +++ b/doc/seriestables.txt Thu Apr 15 09:11:14 2010 -0400 @@ -212,6 +212,74 @@ # where each element is a shared (as in theano.shared) array series['params'].append((epoch,), all_params) +Other targets for appending (e.g. printing to stdout) +----------------------------------------------------- + +SeriesTables was created with an HDF5 file in mind, but often, for debugging, +it's useful to be able to redirect the series elsewhere, notably the standard +output. A mechanism was added to do just that. + +What you do is create an ``AppendTarget`` instance (or more than one) and +pass it in the ``other_targets`` list argument of the Series constructor. For example, to print every +row appended to the standard output, you use StdoutAppendTarget. + +If you want to skip appending to the HDF5 file entirely, this is also +possible. You simply specify ``skip_hdf5_append=True`` in the constructor. You +still need to pass in a valid HDF5 file, though, even though nothing will be +written to it (for, err, legacy reasons). + +Here's an example: + +.. code-block:: python + + def create_series(num_hidden_layers): + + # Replace series we don't want to save with DummySeries, e.g. 
+ # series['training_error'] = DummySeries() + + series = {} + + basedir = os.getcwd() + + h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w") + + # Here we create the new target, with a message prepended + # before every row is printed to stdout + stdout_target = \ + StdoutAppendTarget( \ + prepend='\n-----------------\nValidation error', + indent_str='\t') + + # Notice here we won't even write to the HDF5 file + series['validation_error'] = \ + ErrorSeries(error_name="validation_error", + table_name="validation_error", + hdf5_file=h5f, + index_names=('epoch',), + other_targets=[stdout_target], + skip_hdf5_append=True) + + return series + + +Now calls to series['validation_error'].append() will print to stdout outputs +like:: + + ---------------- + Validation error + timestamp : 1271202144 + cpuclock : 0.12 + epoch : 1 + validation_error : 30.0 + + ---------------- + Validation error + timestamp : 1271202144 + cpuclock : 0.12 + epoch : 2 + validation_error : 26.0 + + Visualizing in vitables ----------------------- @@ -220,3 +288,4 @@ .. _vitables: http://vitables.berlios.de/ .. image:: images/vitables_example_series.png +
--- a/pylearn/io/seriestables/__init__.py Fri Apr 09 11:14:24 2010 -0400 +++ b/pylearn/io/seriestables/__init__.py Thu Apr 15 09:11:14 2010 -0400 @@ -1,2 +1,2 @@ -from series import ErrorSeries, BasicStatisticsSeries, AccumulatorSeriesWrapper, SeriesArrayWrapper, SharedParamsStatisticsWrapper, DummySeries +from series import ErrorSeries, BasicStatisticsSeries, AccumulatorSeriesWrapper, SeriesArrayWrapper, SharedParamsStatisticsWrapper, DummySeries, StdoutAppendTarget, AppendTarget
--- a/pylearn/io/seriestables/series.py Fri Apr 09 11:14:24 2010 -0400 +++ b/pylearn/io/seriestables/series.py Thu Apr 15 09:11:14 2010 -0400 @@ -102,6 +102,52 @@ return LocalDescription + +############################################################################## +# Generic target helpers, other than HDF5 itself + +class AppendTarget(object): + def __init__(self): + pass + + def append(self, table, row): + pass + +class StdoutAppendTarget(AppendTarget): + ''' + Every append() translates into the row being printed on stdout, + each field on a line of the form "column_name : value" + ''' + def __init__(self, prepend='\n', indent_str='\t'): + ''' + Parameters + ---------- + prepend : str + String to prepend before each "append()" is dumped on stdout. + indent_str : str + Chars to prepend to each line + ''' + self.prepend = prepend + self.indent_str = indent_str + + def append(self, table, row): + print self.prepend + pretty_print_row(table, row, self.indent_str) + +def pretty_print_row(table, row, indent): + for key in table.colnames: + print indent, key, ":", row[key] + +class CallbackAppendTarget(AppendTarget): + ''' + Mostly to be used for tests. + ''' + def __init__(self, callback): + self.callback = callback + + def append(self, table, row): + self.callback(table, row) + ############################################################################## # Series classes @@ -122,7 +168,7 @@ except: raise TypeError("index must be a tuple of integers, or at least a single integer") -class Series(): +class Series(object): """ Base Series class, with minimal arguments and type checks. @@ -131,7 +177,8 @@ def __init__(self, table_name, hdf5_file, index_names=('epoch',), title="", hdf5_group='/', - store_timestamp=True, store_cpuclock=True): + store_timestamp=True, store_cpuclock=True, + other_targets=[], skip_hdf5_append=False): """Basic arguments each Series must get. 
Parameters @@ -159,6 +206,8 @@ store_cpuclock : bool Whether to create a column for cpu clock and store it with each record. + other_targets : list of AppendTarget instances + """ ######################################### @@ -195,6 +244,16 @@ if type(store_cpuclock) != bool: raise TypeError("store_timestamp must be a bool") + if type(other_targets) != list: + raise TypeError("other_targets must be a list") + else: + for t in other_targets: + if not isinstance(t, AppendTarget): + raise TypeError("other_targets elements must be instances of AppendTarget") + + if type(skip_hdf5_append) != bool: + raise TypeError("skip_hdf5_append must be a bool") + ######################################### self.table_name = table_name @@ -206,6 +265,9 @@ self.store_timestamp = store_timestamp self.store_cpuclock = store_cpuclock + self.other_targets = other_targets + self.skip_hdf5_append = skip_hdf5_append + def append(self, index, element): raise NotImplementedError @@ -244,7 +306,8 @@ def __init__(self, error_name, table_name, hdf5_file, index_names=('epoch',), title="", hdf5_group='/', - store_timestamp=True, store_cpuclock=True): + store_timestamp=True, store_cpuclock=True, + other_targets=[], skip_hdf5_append=False): """ For most parameters, see Series.__init__ @@ -257,7 +320,9 @@ # most type/value checks are performed in Series.__init__ Series.__init__(self, table_name, hdf5_file, index_names, title, store_timestamp=store_timestamp, - store_cpuclock=store_cpuclock) + store_cpuclock=store_cpuclock, + other_targets=other_targets, + skip_hdf5_append=skip_hdf5_append) if type(error_name) != str: raise TypeError("error_name must be a string") @@ -313,9 +378,13 @@ # adds timestamp and cpuclock to newrow if necessary self._timestamp_cpuclock(newrow) - newrow.append() + for t in self.other_targets: + t.append(self._table, newrow) - self.hdf5_file.flush() + if not self.skip_hdf5_append: + newrow.append() + + self.hdf5_file.flush() # Does not inherit from Series because it does 
not itself need to # access the hdf5_file and does not need a series_name (provided @@ -411,7 +480,8 @@ def __init__(self, table_name, hdf5_file, stats_functions=_basic_stats_functions, index_names=('epoch',), title="", hdf5_group='/', - store_timestamp=True, store_cpuclock=True): + store_timestamp=True, store_cpuclock=True, + other_targets=[], skip_hdf5_append=False): """ For most parameters, see Series.__init__ @@ -429,7 +499,9 @@ # Most type/value checks performed in Series.__init__ Series.__init__(self, table_name, hdf5_file, index_names, title, store_timestamp=store_timestamp, - store_cpuclock=store_cpuclock) + store_cpuclock=store_cpuclock, + other_targets=other_targets, + skip_hdf5_append=skip_hdf5_append) if type(hdf5_group) != str: raise TypeError("hdf5_group must be a string") @@ -487,9 +559,13 @@ self._timestamp_cpuclock(newrow) - newrow.append() + for t in self.other_targets: + t.append(self._table, newrow) - self.hdf5_file.flush() + if not self.skip_hdf5_append: + newrow.append() + + self.hdf5_file.flush() class SeriesArrayWrapper(): """ @@ -546,7 +622,8 @@ def __init__(self, arrays_names, new_group_name, hdf5_file, base_group='/', index_names=('epoch',), title="", - store_timestamp=True, store_cpuclock=True): + store_timestamp=True, store_cpuclock=True, + other_targets=[], skip_hdf5_append=False): """ For other parameters, see Series.__init__ @@ -598,7 +675,9 @@ stats_functions=stats_functions, hdf5_group=new_group._v_pathname, store_timestamp=store_timestamp, - store_cpuclock=store_cpuclock)) + store_cpuclock=store_cpuclock, + other_targets=other_targets, + skip_hdf5_append=skip_hdf5_append)) SeriesArrayWrapper.__init__(self, base_series_list)
--- a/pylearn/io/seriestables/test_series.py Fri Apr 09 11:14:24 2010 -0400 +++ b/pylearn/io/seriestables/test_series.py Thu Apr 15 09:11:14 2010 -0400 @@ -64,6 +64,7 @@ assert compare_lists(table.cols.epoch[:], [1,1,2,2]) assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) + assert len(table) == 4 def test_ErrorSeries_no_index(h5f=None): if not h5f: @@ -301,6 +302,52 @@ assert type(tpl) == tuple and tpl == (928374928374928L,) + + + + +def test_ErrorSeries_appendtarget(h5f=None): + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = tables.openFile(h5f_path, "w") + + validation_errors_from_callback = [] + + def my_callback(table, row): + validation_errors_from_callback.append(row['validation_error']) + + my_callback_target = CallbackAppendTarget(my_callback) + + validation_error = series.ErrorSeries(error_name="validation_error", + table_name="validation_error", + hdf5_file=h5f, + index_names=('minibatch',), + title="Validation error with no index", + other_targets=[my_callback_target], + skip_hdf5_append=True) + + # 2, 3, 4, 5 are minibatch indices (the only index of this series) + validation_error.append(2, 32.0) + validation_error.append(3, 30.0) + validation_error.append(4, 28.0) + validation_error.append(5, 26.0) + + h5f.close() + + h5f = tables.openFile(h5f_path, "r") + + table = h5f.getNode('/', 'validation_error') + + # the table should be empty, since skip_hdf5_append=True + assert len(table) == 0 + + assert compare_lists(validation_errors_from_callback, [32.0,30.0,28.0,26.0]) + + + + + + if __name__ == '__main__': import tempfile test_get_desc() @@ -308,4 +355,5 @@ test_BasicStatisticsSeries_common_case() test_AccumulatorSeriesWrapper_common_case() test_SharedParamsStatisticsWrapper_commoncase() + test_ErrorSeries_appendtarget()