changeset 935:7305246f21f8

Fixed bug with ErrorSeries not supporting hdf5_group parameter. Added test to check for that bug.
author fsavard
date Sat, 17 Apr 2010 18:33:53 -0400
parents e0b960ee57f5
children f732ec90e249
files pylearn/io/seriestables/series.py pylearn/io/seriestables/test_series.py
diffstat 2 files changed, 37 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/io/seriestables/series.py	Thu Apr 15 10:52:02 2010 -0400
+++ b/pylearn/io/seriestables/series.py	Sat Apr 17 18:33:53 2010 -0400
@@ -118,7 +118,7 @@
     Every append() translates into the row being printed on stdout,
     each field on a line of the form "column_name : value"
     '''
-    def __init__(self, prepend='\n', indent_str='\t'):
+    def __init__(self, prepend=None, indent_str='\t'):
         '''
         Parameters
         ----------
@@ -131,7 +131,10 @@
         self.indent_str = indent_str
 
     def append(self, table, row):
-        print self.prepend
+        if not self.prepend:
+            print table._v_pathname
+        else:
+            print self.prepend
         pretty_print_row(table, row, self.indent_str)
 
 def pretty_print_row(table, row, indent):
@@ -319,6 +322,7 @@
 
         # most type/value checks are performed in Series.__init__
         Series.__init__(self, table_name, hdf5_file, index_names, title, 
+                            hdf5_group=hdf5_group,
                             store_timestamp=store_timestamp,
                             store_cpuclock=store_cpuclock,
                             other_targets=other_targets,
--- a/pylearn/io/seriestables/test_series.py	Thu Apr 15 10:52:02 2010 -0400
+++ b/pylearn/io/seriestables/test_series.py	Sat Apr 17 18:33:53 2010 -0400
@@ -66,6 +66,37 @@
     assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
     assert len(table) == 4
 
+
+def test_ErrorSeries_with_group(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    new_group = h5f.createGroup('/','generic_errors')
+
+    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
+                                hdf5_file=h5f, index_names=('epoch','minibatch'),
+                                hdf5_group=new_group._v_pathname,
+                                title="Validation error indexed by epoch and minibatch")
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) index
+    validation_error.append((1,1), 32.0)
+    validation_error.append((1,2), 30.0)
+    validation_error.append((2,1), 28.0)
+    validation_error.append((2,2), 26.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+    
+    table = h5f.getNode('/generic_errors', 'validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [1,1,2,2])
+    assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
+    assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
+    assert len(table) == 4
+
+
 def test_ErrorSeries_no_index(h5f=None):
     if not h5f:
         h5f_path = tempfile.NamedTemporaryFile().name