Mercurial > ift6266
diff utils/tables_series/test_series.py @ 210:dc0d77c8a878
Commented table_series code, changed ParamsStatisticsArray to take shared params instead, create DummySeries to use when we don't want to save a named series
author | savardf |
---|---|
date | Tue, 09 Mar 2010 10:15:19 -0500 |
parents | acb942530923 |
children |
line wrap: on
line diff
--- a/utils/tables_series/test_series.py Fri Mar 05 18:08:34 2010 -0500 +++ b/utils/tables_series/test_series.py Tue Mar 09 10:15:19 2010 -0500 @@ -1,6 +1,9 @@ import tempfile import numpy import numpy.random + +from jobman import DD + from tables import * from series import * @@ -109,20 +112,20 @@ assert compare_lists(table.cols.max[:], [0.30, 0.58, 0.18, 1.9], floats=True) assert compare_lists(table.cols.std[:], [0.06236095, 0.31382939, 0.35640177, 0.85724366], floats=True) -def test_ParamsStatisticsWrapper_commoncase(h5f=None): +def test_SharedParamsStatisticsWrapper_commoncase(h5f=None): import numpy.random if not h5f: h5f_path = tempfile.NamedTemporaryFile().name h5f = openFile(h5f_path, "w") - stats = ParamsStatisticsWrapper(new_group_name="params", base_group="/", + stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", arrays_names=('b1','b2','b3'), hdf5_file=h5f, index_names=('epoch','minibatch')) - b1 = numpy.random.rand(5) - b2 = numpy.random.rand(5) - b3 = numpy.random.rand(5) + b1 = DD({'value':numpy.random.rand(5)}) + b2 = DD({'value':numpy.random.rand(5)}) + b3 = DD({'value':numpy.random.rand(5)}) stats.append((1,1), [b1,b2,b3]) h5f.close() @@ -132,10 +135,10 @@ b1_table = h5f.getNode('/params', 'b1') b3_table = h5f.getNode('/params', 'b3') - assert b1_table.cols.mean[0] - numpy.mean(b1) < 1e-3 - assert b3_table.cols.mean[0] - numpy.mean(b3) < 1e-3 - assert b1_table.cols.min[0] - numpy.min(b1) < 1e-3 - assert b3_table.cols.min[0] - numpy.min(b3) < 1e-3 + assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 + assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3 + assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 + assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 def test_get_desc(): h5f_path = tempfile.NamedTemporaryFile().name @@ -166,5 +169,5 @@ test_ErrorSeries_common_case() test_BasicStatisticsSeries_common_case() test_AccumulatorSeriesWrapper_common_case() - test_ParamsStatisticsWrapper_commoncase() + test_SharedParamsStatisticsWrapper_commoncase()