changeset 929:34d1cd516f76

Added other targets (printing to stdout, notably) to seriestables, and corresponding doc
author fsavard
date Thu, 15 Apr 2010 09:11:14 -0400
parents 87d416e1f4fd
children b2a60af9cc28
files doc/seriestables.txt pylearn/io/seriestables/__init__.py pylearn/io/seriestables/series.py pylearn/io/seriestables/test_series.py
diffstat 4 files changed, 209 insertions(+), 13 deletions(-) [+]
line wrap: on
line diff
--- a/doc/seriestables.txt	Fri Apr 09 11:14:24 2010 -0400
+++ b/doc/seriestables.txt	Thu Apr 15 09:11:14 2010 -0400
@@ -212,6 +212,74 @@
 		# where each element is a shared (as in theano.shared) array
 		series['params'].append((epoch,), all_params)
 
+Other targets for appending (e.g. printing to stdout)
+-----------------------------------------------------
+
+SeriesTables was created with an HDF5 file in mind, but often, for debugging,
+it's useful to be able to redirect the series elsewhere, notably the standard
+output. A mechanism was added to do just that.
+
+What you do is you create a ``AppendTarget`` instance (or more than one) and
+pass it as an argument to the Series constructor. For example, to print every
+row appended to the standard output, you use StdoutAppendTarget.
+
+If you want to skip appending to the HDF5 file entirely, this is also
+possible. You simply specify ``skip_hdf5_append=True`` in the constructor. You
+still need to pass in a valid HDF5 file, though, even though nothing will be
+written to it (for, err, legacy reasons).
+
+Here's an example:
+
+.. code-block:: python
+
+	def create_series(num_hidden_layers):
+
+		# Replace series we don't want to save with DummySeries, e.g.
+		# series['training_error'] = DummySeries()
+
+		series = {}
+
+		basedir = os.getcwd()
+
+		h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w")
+
+		# Here we create the new target, with a message prepended
+		# before every row is printed to stdout
+		stdout_target = \
+			StdoutAppendTarget( \
+				prepend='\n-----------------\nValidation error',
+				indent_str='\t')
+
+		# Notice here we won't even write to the HDF5 file
+		series['validation_error'] = \
+			ErrorSeries(error_name="validation_error",
+				table_name="validation_error",
+				hdf5_file=h5f,
+				index_names=('epoch',),
+				other_targets=[stdout_target],
+				skip_hdf5_append=True)
+
+		return series
+
+		
+Now calls to series['validation_error'].append() will print to stdout outputs
+like::
+
+	----------------
+	Validation error
+		timestamp : 1271202144
+		cpuclock : 0.12
+		epoch : 1
+		validation_error : 30.0
+
+	----------------
+	Validation error
+		timestamp : 1271202144
+		cpuclock : 0.12
+		epoch : 2
+		validation_error : 26.0
+
+
 Visualizing in vitables
 -----------------------
 
@@ -220,3 +288,4 @@
 .. _vitables: http://vitables.berlios.de/
 
 .. image:: images/vitables_example_series.png
+
--- a/pylearn/io/seriestables/__init__.py	Fri Apr 09 11:14:24 2010 -0400
+++ b/pylearn/io/seriestables/__init__.py	Thu Apr 15 09:11:14 2010 -0400
@@ -1,2 +1,2 @@
-from series import ErrorSeries, BasicStatisticsSeries, AccumulatorSeriesWrapper, SeriesArrayWrapper, SharedParamsStatisticsWrapper, DummySeries
+from series import ErrorSeries, BasicStatisticsSeries, AccumulatorSeriesWrapper, SeriesArrayWrapper, SharedParamsStatisticsWrapper, DummySeries, StdoutAppendTarget, AppendTarget
 
--- a/pylearn/io/seriestables/series.py	Fri Apr 09 11:14:24 2010 -0400
+++ b/pylearn/io/seriestables/series.py	Thu Apr 15 09:11:14 2010 -0400
@@ -102,6 +102,52 @@
 
     return LocalDescription
 
+
+##############################################################################
+# Generic target helpers, other than HDF5 itself
+
+class AppendTarget(object):
+    def __init__(self):
+        pass
+
+    def append(self, table, row):
+        pass
+
+class StdoutAppendTarget(AppendTarget):
+    '''
+    Every append() translates into the row being printed on stdout,
+    each field on a line of the form "column_name : value"
+    '''
+    def __init__(self, prepend='\n', indent_str='\t'):
+        '''
+        Parameters
+        ----------
+        prepend : str
+            String to prepend before each "append()" is dumped on stdout.
+        indent_str : str
+            Chars to prepend to each line
+        '''
+        self.prepend = prepend
+        self.indent_str = indent_str
+
+    def append(self, table, row):
+        print self.prepend
+        pretty_print_row(table, row, self.indent_str)
+
+def pretty_print_row(table, row, indent):
+    for key in table.colnames:
+        print indent, key, ":", row[key]
+
+class CallbackAppendTarget(AppendTarget):
+    '''
+    Mostly to be used for tests.
+    '''
+    def __init__(self, callback):
+        self.callback = callback
+
+    def append(self, table, row):
+        self.callback(table, row)
+
 ##############################################################################
 # Series classes
 
@@ -122,7 +168,7 @@
     except:
         raise TypeError("index must be a tuple of integers, or at least a single integer")
 
-class Series():
+class Series(object):
     """
     Base Series class, with minimal arguments and type checks. 
 
@@ -131,7 +177,8 @@
 
     def __init__(self, table_name, hdf5_file, index_names=('epoch',), 
                     title="", hdf5_group='/', 
-                    store_timestamp=True, store_cpuclock=True):
+                    store_timestamp=True, store_cpuclock=True,
+                    other_targets=[], skip_hdf5_append=False):
         """Basic arguments each Series must get.
 
         Parameters
@@ -159,6 +206,8 @@
         store_cpuclock : bool
             Whether to create a column for cpu clock and store it with 
             each record.
+        other_targets : list of str or AppendTarget instances
+            
         """
 
         #########################################
@@ -195,6 +244,16 @@
         if type(store_cpuclock) != bool:
             raise TypeError("store_timestamp must be a bool")
 
+        if type(other_targets) != list:
+            raise TypeError("other_targets must be a list")
+        else:
+            for t in other_targets:
+                if not isinstance(t, AppendTarget):
+                    raise TypeError("other_targets elements must be instances of AppendTarget")
+
+        if type(skip_hdf5_append) != bool:
+            raise TypeError("skip_hdf5_append must be a bool")
+
         #########################################
 
         self.table_name = table_name
@@ -206,6 +265,9 @@
         self.store_timestamp = store_timestamp
         self.store_cpuclock = store_cpuclock
 
+        self.other_targets = other_targets
+        self.skip_hdf5_append = skip_hdf5_append
+
     def append(self, index, element):
         raise NotImplementedError
 
@@ -244,7 +306,8 @@
     def __init__(self, error_name, table_name, 
                     hdf5_file, index_names=('epoch',), 
                     title="", hdf5_group='/', 
-                    store_timestamp=True, store_cpuclock=True):
+                    store_timestamp=True, store_cpuclock=True,
+                    other_targets=[], skip_hdf5_append=False):
         """
         For most parameters, see Series.__init__
 
@@ -257,7 +320,9 @@
         # most type/value checks are performed in Series.__init__
         Series.__init__(self, table_name, hdf5_file, index_names, title, 
                             store_timestamp=store_timestamp,
-                            store_cpuclock=store_cpuclock)
+                            store_cpuclock=store_cpuclock,
+                            other_targets=other_targets,
+                            skip_hdf5_append=skip_hdf5_append)
 
         if type(error_name) != str:
             raise TypeError("error_name must be a string")
@@ -313,9 +378,13 @@
         # adds timestamp and cpuclock to newrow if necessary
         self._timestamp_cpuclock(newrow)
 
-        newrow.append()
+        for t in self.other_targets:
+            t.append(self._table, newrow)
 
-        self.hdf5_file.flush()
+        if not self.skip_hdf5_append:
+            newrow.append()
+
+            self.hdf5_file.flush()
 
 # Does not inherit from Series because it does not itself need to
 # access the hdf5_file and does not need a series_name (provided
@@ -411,7 +480,8 @@
     def __init__(self, table_name, hdf5_file, 
                     stats_functions=_basic_stats_functions, 
                     index_names=('epoch',), title="", hdf5_group='/', 
-                    store_timestamp=True, store_cpuclock=True):
+                    store_timestamp=True, store_cpuclock=True,
+                    other_targets=[], skip_hdf5_append=False):
         """
         For most parameters, see Series.__init__
 
@@ -429,7 +499,9 @@
         # Most type/value checks performed in Series.__init__
         Series.__init__(self, table_name, hdf5_file, index_names, title, 
                             store_timestamp=store_timestamp,
-                            store_cpuclock=store_cpuclock)
+                            store_cpuclock=store_cpuclock,
+                            other_targets=other_targets,
+                            skip_hdf5_append=skip_hdf5_append)
 
         if type(hdf5_group) != str:
             raise TypeError("hdf5_group must be a string")
@@ -487,9 +559,13 @@
 
         self._timestamp_cpuclock(newrow)
 
-        newrow.append()
+        for t in self.other_targets:
+            t.append(self._table, newrow)
 
-        self.hdf5_file.flush()
+        if not self.skip_hdf5_append:
+            newrow.append()
+
+            self.hdf5_file.flush()
 
 class SeriesArrayWrapper():
     """
@@ -546,7 +622,8 @@
 
     def __init__(self, arrays_names, new_group_name, hdf5_file,
                     base_group='/', index_names=('epoch',), title="",
-                    store_timestamp=True, store_cpuclock=True):
+                    store_timestamp=True, store_cpuclock=True,
+                    other_targets=[], skip_hdf5_append=False):
         """
         For other parameters, see Series.__init__
 
@@ -598,7 +675,9 @@
                                 stats_functions=stats_functions,
                                 hdf5_group=new_group._v_pathname,
                                 store_timestamp=store_timestamp,
-                                store_cpuclock=store_cpuclock))
+                                store_cpuclock=store_cpuclock,
+                                other_targets=other_targets,
+                                skip_hdf5_append=skip_hdf5_append))
 
         SeriesArrayWrapper.__init__(self, base_series_list)
 
--- a/pylearn/io/seriestables/test_series.py	Fri Apr 09 11:14:24 2010 -0400
+++ b/pylearn/io/seriestables/test_series.py	Thu Apr 15 09:11:14 2010 -0400
@@ -64,6 +64,7 @@
     assert compare_lists(table.cols.epoch[:], [1,1,2,2])
     assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
     assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
+    assert len(table) == 4
 
 def test_ErrorSeries_no_index(h5f=None):
     if not h5f:
@@ -301,6 +302,52 @@
 
     assert type(tpl) == tuple and tpl == (928374928374928L,)
 
+
+
+
+
+def test_ErrorSeries_appendtarget(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    validation_errors_from_callback = []
+
+    def my_callback(table, row):
+        validation_errors_from_callback.append(row['validation_error'])
+
+    my_callback_target = CallbackAppendTarget(my_callback)
+
+    validation_error = series.ErrorSeries(error_name="validation_error",
+                                table_name="validation_error",
+                                hdf5_file=h5f, 
+                                index_names=('minibatch',),
+                                title="Validation error with no index",
+                                other_targets=[my_callback_target],
+                                skip_hdf5_append=True)
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) index
+    validation_error.append(2, 32.0)
+    validation_error.append(3, 30.0)
+    validation_error.append(4, 28.0)
+    validation_error.append(5, 26.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+    
+    table = h5f.getNode('/', 'validation_error')
+
+    # h5f should be empty
+    assert len(table) == 0
+
+    assert compare_lists(validation_errors_from_callback, [32.0,30.0,28.0,26.0])
+
+
+
+
+
+
 if __name__ == '__main__':
     import tempfile
     test_get_desc()
@@ -308,4 +355,5 @@
     test_BasicStatisticsSeries_common_case()
     test_AccumulatorSeriesWrapper_common_case()
     test_SharedParamsStatisticsWrapper_commoncase()
+    test_ErrorSeries_appendtarget()