Mercurial > ift6266
diff utils/seriestables/test_series.py @ 224:0515a8901c6a
Corrigé un bug avec store_timestamp/cpuclock, et tests pour éviter ce cas
author | fsavard |
---|---|
date | Thu, 11 Mar 2010 11:52:43 -0500 |
parents | e172ef73cdc5 |
children | bfe20d63f88c |
line wrap: on
line diff
--- a/utils/seriestables/test_series.py Thu Mar 11 11:08:42 2010 -0500 +++ b/utils/seriestables/test_series.py Thu Mar 11 11:52:43 2010 -0500 @@ -65,6 +65,52 @@ assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) +def test_ErrorSeries_notimestamp(h5f=None): + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = tables.openFile(h5f_path, "w") + + validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error", + hdf5_file=h5f, index_names=('epoch','minibatch'), + title="Validation error indexed by epoch and minibatch", + store_timestamp=False) + + # (1,1), (1,2) etc. are (epoch, minibatch) index + validation_error.append((1,1), 32.0) + + h5f.close() + + h5f = tables.openFile(h5f_path, "r") + + table = h5f.getNode('/', 'validation_error') + + assert compare_lists(table.cols.epoch[:], [1]) + assert not ("timestamp" in dir(table.cols)) + assert "cpuclock" in dir(table.cols) + +def test_ErrorSeries_nocpuclock(h5f=None): + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = tables.openFile(h5f_path, "w") + + validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error", + hdf5_file=h5f, index_names=('epoch','minibatch'), + title="Validation error indexed by epoch and minibatch", + store_cpuclock=False) + + # (1,1), (1,2) etc. are (epoch, minibatch) index + validation_error.append((1,1), 32.0) + + h5f.close() + + h5f = tables.openFile(h5f_path, "r") + + table = h5f.getNode('/', 'validation_error') + + assert compare_lists(table.cols.epoch[:], [1]) + assert not ("cpuclock" in dir(table.cols)) + assert "timestamp" in dir(table.cols) + def test_AccumulatorSeriesWrapper_common_case(h5f=None): if not h5f: h5f_path = tempfile.NamedTemporaryFile().name @@ -153,6 +199,37 @@ assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 +def test_SharedParamsStatisticsWrapper_notimestamp(h5f=None): + import numpy.random + + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = tables.openFile(h5f_path, "w") + + stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", + arrays_names=('b1','b2','b3'), hdf5_file=h5f, + index_names=('epoch','minibatch'), + store_timestamp=False) + + b1 = DD({'value':numpy.random.rand(5)}) + b2 = DD({'value':numpy.random.rand(5)}) + b3 = DD({'value':numpy.random.rand(5)}) + stats.append((1,1), [b1,b2,b3]) + + h5f.close() + + h5f = tables.openFile(h5f_path, "r") + + b1_table = h5f.getNode('/params', 'b1') + b3_table = h5f.getNode('/params', 'b3') + + assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 + assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3 + assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 + assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 + + assert not ('timestamp' in dir(b1_table.cols)) + def test_get_desc(): h5f_path = tempfile.NamedTemporaryFile().name h5f = tables.openFile(h5f_path, "w")