diff utils/seriestables/test_series.py @ 224:0515a8901c6a

Corrigé un bug avec store_timestamp/cpuclock, et tests pour éviter ce cas
author fsavard
date Thu, 11 Mar 2010 11:52:43 -0500
parents e172ef73cdc5
children bfe20d63f88c
line wrap: on
line diff
--- a/utils/seriestables/test_series.py	Thu Mar 11 11:08:42 2010 -0500
+++ b/utils/seriestables/test_series.py	Thu Mar 11 11:52:43 2010 -0500
@@ -65,6 +65,52 @@
     assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
     assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
 
+def test_ErrorSeries_notimestamp(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
+                                hdf5_file=h5f, index_names=('epoch','minibatch'),
+                                title="Validation error indexed by epoch and minibatch", 
+                                store_timestamp=False)
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) index
+    validation_error.append((1,1), 32.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+    
+    table = h5f.getNode('/', 'validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [1])
+    assert not ("timestamp" in dir(table.cols))
+    assert "cpuclock" in dir(table.cols)
+
+def test_ErrorSeries_nocpuclock(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
+                                hdf5_file=h5f, index_names=('epoch','minibatch'),
+                                title="Validation error indexed by epoch and minibatch", 
+                                store_cpuclock=False)
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) index
+    validation_error.append((1,1), 32.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+    
+    table = h5f.getNode('/', 'validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [1])
+    assert not ("cpuclock" in dir(table.cols))
+    assert "timestamp" in dir(table.cols)
+
 def test_AccumulatorSeriesWrapper_common_case(h5f=None):
     if not h5f:
         h5f_path = tempfile.NamedTemporaryFile().name
@@ -153,6 +199,37 @@
     assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
     assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3
 
+def test_SharedParamsStatisticsWrapper_notimestamp(h5f=None):
+    import numpy.random
+
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
+                                arrays_names=('b1','b2','b3'), hdf5_file=h5f,
+                                index_names=('epoch','minibatch'),
+                                store_timestamp=False)
+
+    b1 = DD({'value':numpy.random.rand(5)})
+    b2 = DD({'value':numpy.random.rand(5)})
+    b3 = DD({'value':numpy.random.rand(5)})
+    stats.append((1,1), [b1,b2,b3])
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+
+    b1_table = h5f.getNode('/params', 'b1')
+    b3_table = h5f.getNode('/params', 'b3')
+
+    assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3
+    assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3
+    assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
+    assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3
+
+    assert not ('timestamp' in dir(b1_table.cols))
+
 def test_get_desc():
     h5f_path = tempfile.NamedTemporaryFile().name
     h5f = tables.openFile(h5f_path, "w")