Mercurial > ift6266
comparison utils/tables_series/test_series.py @ 210:dc0d77c8a878
Commented the table_series code; changed ParamsStatisticsArray to take shared params instead; created DummySeries for use when we don't want to save a named series.
author | savardf |
---|---|
date | Tue, 09 Mar 2010 10:15:19 -0500 |
parents | acb942530923 |
children |
comparison
equal
deleted
inserted
replaced
209:d982dfa583df | 210:dc0d77c8a878 |
---|---|
1 import tempfile | 1 import tempfile |
2 import numpy | 2 import numpy |
3 import numpy.random | 3 import numpy.random |
4 | |
5 from jobman import DD | |
6 | |
4 from tables import * | 7 from tables import * |
5 | 8 |
6 from series import * | 9 from series import * |
7 | 10 |
8 | 11 |
107 assert compare_lists(table.cols.mean[:], [0.21666667, 0.23333333, -0.29333332, 0.69], floats=True) | 110 assert compare_lists(table.cols.mean[:], [0.21666667, 0.23333333, -0.29333332, 0.69], floats=True) |
108 assert compare_lists(table.cols.min[:], [0.15000001, -0.18000001, -0.68000001, 0.02], floats=True) | 111 assert compare_lists(table.cols.min[:], [0.15000001, -0.18000001, -0.68000001, 0.02], floats=True) |
109 assert compare_lists(table.cols.max[:], [0.30, 0.58, 0.18, 1.9], floats=True) | 112 assert compare_lists(table.cols.max[:], [0.30, 0.58, 0.18, 1.9], floats=True) |
110 assert compare_lists(table.cols.std[:], [0.06236095, 0.31382939, 0.35640177, 0.85724366], floats=True) | 113 assert compare_lists(table.cols.std[:], [0.06236095, 0.31382939, 0.35640177, 0.85724366], floats=True) |
111 | 114 |
112 def test_ParamsStatisticsWrapper_commoncase(h5f=None): | 115 def test_SharedParamsStatisticsWrapper_commoncase(h5f=None): |
113 import numpy.random | 116 import numpy.random |
114 | 117 |
115 if not h5f: | 118 if not h5f: |
116 h5f_path = tempfile.NamedTemporaryFile().name | 119 h5f_path = tempfile.NamedTemporaryFile().name |
117 h5f = openFile(h5f_path, "w") | 120 h5f = openFile(h5f_path, "w") |
118 | 121 |
119 stats = ParamsStatisticsWrapper(new_group_name="params", base_group="/", | 122 stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", |
120 arrays_names=('b1','b2','b3'), hdf5_file=h5f, | 123 arrays_names=('b1','b2','b3'), hdf5_file=h5f, |
121 index_names=('epoch','minibatch')) | 124 index_names=('epoch','minibatch')) |
122 | 125 |
123 b1 = numpy.random.rand(5) | 126 b1 = DD({'value':numpy.random.rand(5)}) |
124 b2 = numpy.random.rand(5) | 127 b2 = DD({'value':numpy.random.rand(5)}) |
125 b3 = numpy.random.rand(5) | 128 b3 = DD({'value':numpy.random.rand(5)}) |
126 stats.append((1,1), [b1,b2,b3]) | 129 stats.append((1,1), [b1,b2,b3]) |
127 | 130 |
128 h5f.close() | 131 h5f.close() |
129 | 132 |
130 h5f = openFile(h5f_path, "r") | 133 h5f = openFile(h5f_path, "r") |
131 | 134 |
132 b1_table = h5f.getNode('/params', 'b1') | 135 b1_table = h5f.getNode('/params', 'b1') |
133 b3_table = h5f.getNode('/params', 'b3') | 136 b3_table = h5f.getNode('/params', 'b3') |
134 | 137 |
135 assert b1_table.cols.mean[0] - numpy.mean(b1) < 1e-3 | 138 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 |
136 assert b3_table.cols.mean[0] - numpy.mean(b3) < 1e-3 | 139 assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3 |
137 assert b1_table.cols.min[0] - numpy.min(b1) < 1e-3 | 140 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 |
138 assert b3_table.cols.min[0] - numpy.min(b3) < 1e-3 | 141 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 |
139 | 142 |
140 def test_get_desc(): | 143 def test_get_desc(): |
141 h5f_path = tempfile.NamedTemporaryFile().name | 144 h5f_path = tempfile.NamedTemporaryFile().name |
142 h5f = openFile(h5f_path, "w") | 145 h5f = openFile(h5f_path, "w") |
143 | 146 |
164 import tempfile | 167 import tempfile |
165 test_get_desc() | 168 test_get_desc() |
166 test_ErrorSeries_common_case() | 169 test_ErrorSeries_common_case() |
167 test_BasicStatisticsSeries_common_case() | 170 test_BasicStatisticsSeries_common_case() |
168 test_AccumulatorSeriesWrapper_common_case() | 171 test_AccumulatorSeriesWrapper_common_case() |
169 test_ParamsStatisticsWrapper_commoncase() | 172 test_SharedParamsStatisticsWrapper_commoncase() |
170 | 173 |