comparison utils/seriestables/test_series.py @ 224:0515a8901c6a

Corrigé un bug avec store_timestamp/cpuclock, et tests pour éviter ce cas
author fsavard
date Thu, 11 Mar 2010 11:52:43 -0500
parents e172ef73cdc5
children bfe20d63f88c
comparison
equal deleted inserted replaced
221:02d9c1279dd8 224:0515a8901c6a
63 63
64 assert compare_lists(table.cols.epoch[:], [1,1,2,2]) 64 assert compare_lists(table.cols.epoch[:], [1,1,2,2])
65 assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) 65 assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
66 assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) 66 assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
67 67
def test_ErrorSeries_notimestamp(h5f=None):
    """Check that ErrorSeries drops the timestamp column when store_timestamp=False.

    One (epoch, minibatch) entry is appended. The file is then closed,
    reopened read-only, and checked: the epoch column holds the stored
    index, there is no "timestamp" column, and "cpuclock" is still present.

    h5f: optional open, writable PyTables file to write into. When omitted,
    a file is created at a temporary path. Either way the file is closed
    by this test and reopened read-only from its path.
    """
    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")
    else:
        # Previously h5f_path was only bound in the branch above, so passing
        # a file in raised NameError at the reopen below.
        h5f_path = h5f.filename

    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
                                          hdf5_file=h5f, index_names=('epoch', 'minibatch'),
                                          title="Validation error indexed by epoch and minibatch",
                                          store_timestamp=False)

    # (1,1) is an (epoch, minibatch) index
    validation_error.append((1, 1), 32.0)

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")
    try:
        table = h5f.getNode('/', 'validation_error')

        assert compare_lists(table.cols.epoch[:], [1])
        # colnames is the documented list of top-level column names.
        assert "timestamp" not in table.colnames
        assert "cpuclock" in table.colnames
    finally:
        # Release the read-only handle even when an assertion fails.
        h5f.close()
90
def test_ErrorSeries_nocpuclock(h5f=None):
    """Check that ErrorSeries drops the cpuclock column when store_cpuclock=False.

    One (epoch, minibatch) entry is appended. The file is then closed,
    reopened read-only, and checked: the epoch column holds the stored
    index, there is no "cpuclock" column, and "timestamp" is still present.

    h5f: optional open, writable PyTables file to write into. When omitted,
    a file is created at a temporary path. Either way the file is closed
    by this test and reopened read-only from its path.
    """
    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")
    else:
        # Previously h5f_path was only bound in the branch above, so passing
        # a file in raised NameError at the reopen below.
        h5f_path = h5f.filename

    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
                                          hdf5_file=h5f, index_names=('epoch', 'minibatch'),
                                          title="Validation error indexed by epoch and minibatch",
                                          store_cpuclock=False)

    # (1,1) is an (epoch, minibatch) index
    validation_error.append((1, 1), 32.0)

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")
    try:
        table = h5f.getNode('/', 'validation_error')

        assert compare_lists(table.cols.epoch[:], [1])
        # colnames is the documented list of top-level column names.
        assert "cpuclock" not in table.colnames
        assert "timestamp" in table.colnames
    finally:
        # Release the read-only handle even when an assertion fails.
        h5f.close()
113
68 def test_AccumulatorSeriesWrapper_common_case(h5f=None): 114 def test_AccumulatorSeriesWrapper_common_case(h5f=None):
69 if not h5f: 115 if not h5f:
70 h5f_path = tempfile.NamedTemporaryFile().name 116 h5f_path = tempfile.NamedTemporaryFile().name
71 h5f = tables.openFile(h5f_path, "w") 117 h5f = tables.openFile(h5f_path, "w")
72 118
150 196
151 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 197 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3
152 assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3 198 assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3
153 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 199 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
154 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 200 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3
201
def test_SharedParamsStatisticsWrapper_notimestamp(h5f=None):
    """Check that SharedParamsStatisticsWrapper honours store_timestamp=False.

    Statistics for three random 5-element parameter vectors (b1, b2, b3)
    are appended under /params at index (1, 1). The file is then reopened
    read-only, and the test checks two things:
    - the stored mean and min of b1 and b3 match numpy's values within 1e-3;
    - b1's table has no "timestamp" column.

    h5f: optional open, writable PyTables file to write into. When omitted,
    a file is created at a temporary path. Either way the file is closed
    by this test and reopened read-only from its path.
    """
    import numpy.random

    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")
    else:
        # Previously h5f_path was only bound in the branch above, so passing
        # a file in raised NameError at the reopen below.
        h5f_path = h5f.filename

    stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
                                          arrays_names=('b1', 'b2', 'b3'), hdf5_file=h5f,
                                          index_names=('epoch', 'minibatch'),
                                          store_timestamp=False)

    b1 = DD({'value': numpy.random.rand(5)})
    b2 = DD({'value': numpy.random.rand(5)})
    b3 = DD({'value': numpy.random.rand(5)})
    stats.append((1, 1), [b1, b2, b3])

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")
    try:
        b1_table = h5f.getNode('/params', 'b1')
        b3_table = h5f.getNode('/params', 'b3')

        # abs() is required: a bare difference would pass for any stored
        # value smaller than the expected one, however far off it is.
        assert abs(b1_table.cols.mean[0] - numpy.mean(b1.value)) < 1e-3
        assert abs(b3_table.cols.mean[0] - numpy.mean(b3.value)) < 1e-3
        assert abs(b1_table.cols.min[0] - numpy.min(b1.value)) < 1e-3
        assert abs(b3_table.cols.min[0] - numpy.min(b3.value)) < 1e-3

        # colnames is the documented list of top-level column names.
        assert 'timestamp' not in b1_table.colnames
    finally:
        # Release the read-only handle even when an assertion fails.
        h5f.close()
155 232
156 def test_get_desc(): 233 def test_get_desc():
157 h5f_path = tempfile.NamedTemporaryFile().name 234 h5f_path = tempfile.NamedTemporaryFile().name
158 h5f = tables.openFile(h5f_path, "w") 235 h5f = tables.openFile(h5f_path, "w")
159 236