Mercurial > pylearn
changeset 935:7305246f21f8
Fixed bug with ErrorSeries not supporting hdf5_group parameter. Added test to check for that bug.
author | fsavard |
---|---|
date | Sat, 17 Apr 2010 18:33:53 -0400 |
parents | e0b960ee57f5 |
children | f732ec90e249 |
files | pylearn/io/seriestables/series.py pylearn/io/seriestables/test_series.py |
diffstat | 2 files changed, 37 insertions(+), 2 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/io/seriestables/series.py Thu Apr 15 10:52:02 2010 -0400 +++ b/pylearn/io/seriestables/series.py Sat Apr 17 18:33:53 2010 -0400 @@ -118,7 +118,7 @@ Every append() translates into the row being printed on stdout, each field on a line of the form "column_name : value" ''' - def __init__(self, prepend='\n', indent_str='\t'): + def __init__(self, prepend=None, indent_str='\t'): ''' Parameters ---------- @@ -131,7 +131,10 @@ self.indent_str = indent_str def append(self, table, row): - print self.prepend + if not self.prepend: + print table._v_pathname + else: + print self.prepend pretty_print_row(table, row, self.indent_str) def pretty_print_row(table, row, indent): @@ -319,6 +322,7 @@ # most type/value checks are performed in Series.__init__ Series.__init__(self, table_name, hdf5_file, index_names, title, + hdf5_group=hdf5_group, store_timestamp=store_timestamp, store_cpuclock=store_cpuclock, other_targets=other_targets,
--- a/pylearn/io/seriestables/test_series.py Thu Apr 15 10:52:02 2010 -0400 +++ b/pylearn/io/seriestables/test_series.py Sat Apr 17 18:33:53 2010 -0400 @@ -66,6 +66,37 @@ assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) assert len(table) == 4 + +def test_ErrorSeries_with_group(h5f=None): + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = tables.openFile(h5f_path, "w") + + new_group = h5f.createGroup('/','generic_errors') + + validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error", + hdf5_file=h5f, index_names=('epoch','minibatch'), + hdf5_group=new_group._v_pathname, + title="Validation error indexed by epoch and minibatch") + + # (1,1), (1,2) etc. are (epoch, minibatch) index + validation_error.append((1,1), 32.0) + validation_error.append((1,2), 30.0) + validation_error.append((2,1), 28.0) + validation_error.append((2,2), 26.0) + + h5f.close() + + h5f = tables.openFile(h5f_path, "r") + + table = h5f.getNode('/generic_errors', 'validation_error') + + assert compare_lists(table.cols.epoch[:], [1,1,2,2]) + assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) + assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) + assert len(table) == 4 + + def test_ErrorSeries_no_index(h5f=None): if not h5f: h5f_path = tempfile.NamedTemporaryFile().name