diff utils/tables_series/test_series.py @ 210:dc0d77c8a878

Commented table_series code, changed ParamsStatisticsArray to take shared params instead, create DummySeries to use when we don't want to save a named series
author savardf
date Tue, 09 Mar 2010 10:15:19 -0500
parents acb942530923
children
line wrap: on
line diff
--- a/utils/tables_series/test_series.py	Fri Mar 05 18:08:34 2010 -0500
+++ b/utils/tables_series/test_series.py	Tue Mar 09 10:15:19 2010 -0500
@@ -1,6 +1,9 @@
 import tempfile
 import numpy
 import numpy.random
+
+from jobman import DD
+
 from tables import *
 
 from series import *
@@ -109,20 +112,20 @@
     assert compare_lists(table.cols.max[:], [0.30, 0.58, 0.18, 1.9], floats=True)
     assert compare_lists(table.cols.std[:], [0.06236095, 0.31382939,  0.35640177, 0.85724366], floats=True)
 
-def test_ParamsStatisticsWrapper_commoncase(h5f=None):
+def test_SharedParamsStatisticsWrapper_commoncase(h5f=None):
     import numpy.random
 
     if not h5f:
         h5f_path = tempfile.NamedTemporaryFile().name
         h5f = openFile(h5f_path, "w")
 
-    stats = ParamsStatisticsWrapper(new_group_name="params", base_group="/",
+    stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
                                 arrays_names=('b1','b2','b3'), hdf5_file=h5f,
                                 index_names=('epoch','minibatch'))
 
-    b1 = numpy.random.rand(5)
-    b2 = numpy.random.rand(5)
-    b3 = numpy.random.rand(5)
+    b1 = DD({'value':numpy.random.rand(5)})
+    b2 = DD({'value':numpy.random.rand(5)})
+    b3 = DD({'value':numpy.random.rand(5)})
     stats.append((1,1), [b1,b2,b3])
 
     h5f.close()
@@ -132,10 +135,10 @@
     b1_table = h5f.getNode('/params', 'b1')
     b3_table = h5f.getNode('/params', 'b3')
 
-    assert b1_table.cols.mean[0] - numpy.mean(b1) < 1e-3
-    assert b3_table.cols.mean[0] - numpy.mean(b3) < 1e-3
-    assert b1_table.cols.min[0] - numpy.min(b1) < 1e-3
-    assert b3_table.cols.min[0] - numpy.min(b3) < 1e-3
+    assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3
+    assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3
+    assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
+    assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3
 
 def test_get_desc():
     h5f_path = tempfile.NamedTemporaryFile().name
@@ -166,5 +169,5 @@
     test_ErrorSeries_common_case()
     test_BasicStatisticsSeries_common_case()
     test_AccumulatorSeriesWrapper_common_case()
-    test_ParamsStatisticsWrapper_commoncase()
+    test_SharedParamsStatisticsWrapper_commoncase()