Mercurial > ift6266
diff utils/seriestables/test_series.py @ 213:a96fa4de06d2
Renommé mon module de séries
author | fsavard |
---|---|
date | Wed, 10 Mar 2010 16:52:22 -0500 |
parents | utils/tables_series/test_series.py@dc0d77c8a878 |
children | 4c137f16b013 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/utils/seriestables/test_series.py Wed Mar 10 16:52:22 2010 -0500 @@ -0,0 +1,173 @@ +import tempfile +import numpy +import numpy.random + +from jobman import DD + +from tables import * + +from series import * + + +def compare_floats(f1,f2): + if f1-f2 < 1e-3: + return True + return False + +def compare_lists(it1, it2, floats=False): + if len(it1) != len(it2): + return False + + for el1, el2 in zip(it1, it2): + if floats: + if not compare_floats(el1,el2): + return False + elif el1 != el2: + return False + + return True + +def test_ErrorSeries_common_case(h5f=None): + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = openFile(h5f_path, "w") + + validation_error = ErrorSeries(error_name="validation_error", table_name="validation_error", + hdf5_file=h5f, index_names=('epoch','minibatch'), + title="Validation error indexed by epoch and minibatch") + + # (1,1), (1,2) etc. are (epoch, minibatch) index + validation_error.append((1,1), 32.0) + validation_error.append((1,2), 30.0) + validation_error.append((2,1), 28.0) + validation_error.append((2,2), 26.0) + + h5f.close() + + h5f = openFile(h5f_path, "r") + + table = h5f.getNode('/', 'validation_error') + + assert compare_lists(table.cols.epoch[:], [1,1,2,2]) + assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) + assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) + +def test_AccumulatorSeriesWrapper_common_case(h5f=None): + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = openFile(h5f_path, "w") + + validation_error = ErrorSeries(error_name="accumulated_validation_error", + table_name="accumulated_validation_error", + hdf5_file=h5f, + index_names=('epoch','minibatch'), + title="Validation error, summed every 3 minibatches, indexed by epoch and minibatch") + + accumulator = AccumulatorSeriesWrapper(base_series=validation_error, + reduce_every=3, reduce_function=numpy.sum) + + # (1,1), (1,2) etc. are (epoch, minibatch) index + accumulator.append((1,1), 32.0) + accumulator.append((1,2), 30.0) + accumulator.append((2,1), 28.0) + accumulator.append((2,2), 26.0) + accumulator.append((3,1), 24.0) + accumulator.append((3,2), 22.0) + + h5f.close() + + h5f = openFile(h5f_path, "r") + + table = h5f.getNode('/', 'accumulated_validation_error') + + assert compare_lists(table.cols.epoch[:], [2,3]) + assert compare_lists(table.cols.minibatch[:], [1,2]) + assert compare_lists(table.cols.accumulated_validation_error[:], [90.0,72.0], floats=True) + +def test_BasicStatisticsSeries_common_case(h5f=None): + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = openFile(h5f_path, "w") + + stats_series = BasicStatisticsSeries(table_name="b_vector_statistics", + hdf5_file=h5f, index_names=('epoch','minibatch'), + title="Basic statistics for b vector indexed by epoch and minibatch") + + # (1,1), (1,2) etc. are (epoch, minibatch) index + stats_series.append((1,1), [0.15, 0.20, 0.30]) + stats_series.append((1,2), [-0.18, 0.30, 0.58]) + stats_series.append((2,1), [0.18, -0.38, -0.68]) + stats_series.append((2,2), [0.15, 0.02, 1.9]) + + h5f.close() + + h5f = openFile(h5f_path, "r") + + table = h5f.getNode('/', 'b_vector_statistics') + + assert compare_lists(table.cols.epoch[:], [1,1,2,2]) + assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) + assert compare_lists(table.cols.mean[:], [0.21666667, 0.23333333, -0.29333332, 0.69], floats=True) + assert compare_lists(table.cols.min[:], [0.15000001, -0.18000001, -0.68000001, 0.02], floats=True) + assert compare_lists(table.cols.max[:], [0.30, 0.58, 0.18, 1.9], floats=True) + assert compare_lists(table.cols.std[:], [0.06236095, 0.31382939, 0.35640177, 0.85724366], floats=True) + +def test_SharedParamsStatisticsWrapper_commoncase(h5f=None): + import numpy.random + + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = openFile(h5f_path, "w") + + stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", + arrays_names=('b1','b2','b3'), hdf5_file=h5f, + index_names=('epoch','minibatch')) + + b1 = DD({'value':numpy.random.rand(5)}) + b2 = DD({'value':numpy.random.rand(5)}) + b3 = DD({'value':numpy.random.rand(5)}) + stats.append((1,1), [b1,b2,b3]) + + h5f.close() + + h5f = openFile(h5f_path, "r") + + b1_table = h5f.getNode('/params', 'b1') + b3_table = h5f.getNode('/params', 'b3') + + assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 + assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3 + assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 + assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 + +def test_get_desc(): + h5f_path = tempfile.NamedTemporaryFile().name + h5f = openFile(h5f_path, "w") + + desc = get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4")) + + mytable = h5f.createTable('/', 'mytable', desc) + + # just make sure the columns are there... otherwise this will throw an exception + mytable.cols.col1 + mytable.cols.col2 + mytable.cols.col3 + mytable.cols.col4 + + try: + # this should fail... LocalDescription must be local to get_desc_etc + test = LocalDescription + assert False + except: + assert True + + assert True + +if __name__ == '__main__': + import tempfile + test_get_desc() + test_ErrorSeries_common_case() + test_BasicStatisticsSeries_common_case() + test_AccumulatorSeriesWrapper_common_case() + test_SharedParamsStatisticsWrapper_commoncase() +