Mercurial > ift6266
comparison utils/seriestables/test_series.py @ 224:0515a8901c6a
Corrigé un bug avec store_timestamp/cpuclock, et tests pour éviter ce cas
author | fsavard |
---|---|
date | Thu, 11 Mar 2010 11:52:43 -0500 |
parents | e172ef73cdc5 |
children | bfe20d63f88c |
comparison
equal
deleted
inserted
replaced
221:02d9c1279dd8 | 224:0515a8901c6a |
---|---|
63 | 63 |
64 assert compare_lists(table.cols.epoch[:], [1,1,2,2]) | 64 assert compare_lists(table.cols.epoch[:], [1,1,2,2]) |
65 assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) | 65 assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) |
66 assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) | 66 assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) |
67 | 67 |
68 def test_ErrorSeries_notimestamp(h5f=None): | |
69 if not h5f: | |
70 h5f_path = tempfile.NamedTemporaryFile().name | |
71 h5f = tables.openFile(h5f_path, "w") | |
72 | |
73 validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error", | |
74 hdf5_file=h5f, index_names=('epoch','minibatch'), | |
75 title="Validation error indexed by epoch and minibatch", | |
76 store_timestamp=False) | |
77 | |
78 # (1,1), (1,2) etc. are (epoch, minibatch) index | |
79 validation_error.append((1,1), 32.0) | |
80 | |
81 h5f.close() | |
82 | |
83 h5f = tables.openFile(h5f_path, "r") | |
84 | |
85 table = h5f.getNode('/', 'validation_error') | |
86 | |
87 assert compare_lists(table.cols.epoch[:], [1]) | |
88 assert not ("timestamp" in dir(table.cols)) | |
89 assert "cpuclock" in dir(table.cols) | |
90 | |
91 def test_ErrorSeries_nocpuclock(h5f=None): | |
92 if not h5f: | |
93 h5f_path = tempfile.NamedTemporaryFile().name | |
94 h5f = tables.openFile(h5f_path, "w") | |
95 | |
96 validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error", | |
97 hdf5_file=h5f, index_names=('epoch','minibatch'), | |
98 title="Validation error indexed by epoch and minibatch", | |
99 store_cpuclock=False) | |
100 | |
101 # (1,1), (1,2) etc. are (epoch, minibatch) index | |
102 validation_error.append((1,1), 32.0) | |
103 | |
104 h5f.close() | |
105 | |
106 h5f = tables.openFile(h5f_path, "r") | |
107 | |
108 table = h5f.getNode('/', 'validation_error') | |
109 | |
110 assert compare_lists(table.cols.epoch[:], [1]) | |
111 assert not ("cpuclock" in dir(table.cols)) | |
112 assert "timestamp" in dir(table.cols) | |
113 | |
68 def test_AccumulatorSeriesWrapper_common_case(h5f=None): | 114 def test_AccumulatorSeriesWrapper_common_case(h5f=None): |
69 if not h5f: | 115 if not h5f: |
70 h5f_path = tempfile.NamedTemporaryFile().name | 116 h5f_path = tempfile.NamedTemporaryFile().name |
71 h5f = tables.openFile(h5f_path, "w") | 117 h5f = tables.openFile(h5f_path, "w") |
72 | 118 |
150 | 196 |
151 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 | 197 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 |
152 assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3 | 198 assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3 |
153 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 | 199 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 |
154 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 | 200 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 |
201 | |
202 def test_SharedParamsStatisticsWrapper_notimestamp(h5f=None): | |
203 import numpy.random | |
204 | |
205 if not h5f: | |
206 h5f_path = tempfile.NamedTemporaryFile().name | |
207 h5f = tables.openFile(h5f_path, "w") | |
208 | |
209 stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", | |
210 arrays_names=('b1','b2','b3'), hdf5_file=h5f, | |
211 index_names=('epoch','minibatch'), | |
212 store_timestamp=False) | |
213 | |
214 b1 = DD({'value':numpy.random.rand(5)}) | |
215 b2 = DD({'value':numpy.random.rand(5)}) | |
216 b3 = DD({'value':numpy.random.rand(5)}) | |
217 stats.append((1,1), [b1,b2,b3]) | |
218 | |
219 h5f.close() | |
220 | |
221 h5f = tables.openFile(h5f_path, "r") | |
222 | |
223 b1_table = h5f.getNode('/params', 'b1') | |
224 b3_table = h5f.getNode('/params', 'b3') | |
225 | |
226 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 | |
227 assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3 | |
228 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 | |
229 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 | |
230 | |
231 assert not ('timestamp' in dir(b1_table.cols)) | |
155 | 232 |
156 def test_get_desc(): | 233 def test_get_desc(): |
157 h5f_path = tempfile.NamedTemporaryFile().name | 234 h5f_path = tempfile.NamedTemporaryFile().name |
158 h5f = tables.openFile(h5f_path, "w") | 235 h5f = tables.openFile(h5f_path, "w") |
159 | 236 |