diff utils/seriestables/test_series.py @ 213:a96fa4de06d2

Renommé mon module de séries
author fsavard
date Wed, 10 Mar 2010 16:52:22 -0500
parents utils/tables_series/test_series.py@dc0d77c8a878
children 4c137f16b013
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/utils/seriestables/test_series.py	Wed Mar 10 16:52:22 2010 -0500
@@ -0,0 +1,173 @@
+import tempfile
+import numpy
+import numpy.random
+
+from jobman import DD
+
+from tables import *
+
+from series import *
+
+
+def compare_floats(f1,f2):
+    if f1-f2 < 1e-3:
+        return True
+    return False
+
+def compare_lists(it1, it2, floats=False):
+    if len(it1) != len(it2):
+        return False
+
+    for el1,  el2 in zip(it1, it2):
+        if floats:
+            if not compare_floats(el1,el2):
+                return False
+        elif el1 != el2:
+            return False
+
+    return True
+
+def test_ErrorSeries_common_case(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = openFile(h5f_path, "w")
+
+    validation_error = ErrorSeries(error_name="validation_error", table_name="validation_error",
+                                hdf5_file=h5f, index_names=('epoch','minibatch'),
+                                title="Validation error indexed by epoch and minibatch")
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) index
+    validation_error.append((1,1), 32.0)
+    validation_error.append((1,2), 30.0)
+    validation_error.append((2,1), 28.0)
+    validation_error.append((2,2), 26.0)
+
+    h5f.close()
+
+    h5f = openFile(h5f_path, "r")
+    
+    table = h5f.getNode('/', 'validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [1,1,2,2])
+    assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
+    assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
+
+def test_AccumulatorSeriesWrapper_common_case(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = openFile(h5f_path, "w")
+
+    validation_error = ErrorSeries(error_name="accumulated_validation_error",
+                                table_name="accumulated_validation_error",
+                                hdf5_file=h5f,
+                                index_names=('epoch','minibatch'),
+                                title="Validation error, summed every 3 minibatches, indexed by epoch and minibatch")
+
+    accumulator = AccumulatorSeriesWrapper(base_series=validation_error,
+                                    reduce_every=3, reduce_function=numpy.sum)
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) index
+    accumulator.append((1,1), 32.0)
+    accumulator.append((1,2), 30.0)
+    accumulator.append((2,1), 28.0)
+    accumulator.append((2,2), 26.0)
+    accumulator.append((3,1), 24.0)
+    accumulator.append((3,2), 22.0)
+
+    h5f.close()
+
+    h5f = openFile(h5f_path, "r")
+    
+    table = h5f.getNode('/', 'accumulated_validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [2,3])
+    assert compare_lists(table.cols.minibatch[:], [1,2])
+    assert compare_lists(table.cols.accumulated_validation_error[:], [90.0,72.0], floats=True)
+
+def test_BasicStatisticsSeries_common_case(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = openFile(h5f_path, "w")
+
+    stats_series = BasicStatisticsSeries(table_name="b_vector_statistics",
+                                hdf5_file=h5f, index_names=('epoch','minibatch'),
+                                title="Basic statistics for b vector indexed by epoch and minibatch")
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) index
+    stats_series.append((1,1), [0.15, 0.20, 0.30])
+    stats_series.append((1,2), [-0.18, 0.30, 0.58])
+    stats_series.append((2,1), [0.18, -0.38, -0.68])
+    stats_series.append((2,2), [0.15, 0.02, 1.9])
+
+    h5f.close()
+
+    h5f = openFile(h5f_path, "r")
+    
+    table = h5f.getNode('/', 'b_vector_statistics')
+
+    assert compare_lists(table.cols.epoch[:], [1,1,2,2])
+    assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
+    assert compare_lists(table.cols.mean[:], [0.21666667,  0.23333333, -0.29333332,  0.69], floats=True)
+    assert compare_lists(table.cols.min[:], [0.15000001, -0.18000001, -0.68000001,  0.02], floats=True)
+    assert compare_lists(table.cols.max[:], [0.30, 0.58, 0.18, 1.9], floats=True)
+    assert compare_lists(table.cols.std[:], [0.06236095, 0.31382939,  0.35640177, 0.85724366], floats=True)
+
+def test_SharedParamsStatisticsWrapper_commoncase(h5f=None):
+    import numpy.random
+
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = openFile(h5f_path, "w")
+
+    stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
+                                arrays_names=('b1','b2','b3'), hdf5_file=h5f,
+                                index_names=('epoch','minibatch'))
+
+    b1 = DD({'value':numpy.random.rand(5)})
+    b2 = DD({'value':numpy.random.rand(5)})
+    b3 = DD({'value':numpy.random.rand(5)})
+    stats.append((1,1), [b1,b2,b3])
+
+    h5f.close()
+
+    h5f = openFile(h5f_path, "r")
+
+    b1_table = h5f.getNode('/params', 'b1')
+    b3_table = h5f.getNode('/params', 'b3')
+
+    assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3
+    assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3
+    assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
+    assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3
+
+def test_get_desc():
+    h5f_path = tempfile.NamedTemporaryFile().name
+    h5f = openFile(h5f_path, "w")
+
+    desc = get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4"))
+
+    mytable = h5f.createTable('/', 'mytable', desc)
+
+    # just make sure the columns are there... otherwise this will throw an exception
+    mytable.cols.col1
+    mytable.cols.col2
+    mytable.cols.col3
+    mytable.cols.col4
+
+    try:
+        # this should fail... LocalDescription must be local to get_desc_etc
+        test = LocalDescription
+        assert False
+    except:
+        assert True
+
+    assert True
+
+if __name__ == '__main__':
+    import tempfile
+    test_get_desc()
+    test_ErrorSeries_common_case()
+    test_BasicStatisticsSeries_common_case()
+    test_AccumulatorSeriesWrapper_common_case()
+    test_SharedParamsStatisticsWrapper_commoncase()
+