Mercurial > ift6266
changeset 224:0515a8901c6a
Corrigé un bug avec store_timestamp/cpuclock, et tests pour éviter ce cas
author | fsavard |
---|---|
date | Thu, 11 Mar 2010 11:52:43 -0500 |
parents | 02d9c1279dd8 |
children | eb78a695ad7a |
files | utils/seriestables/series.py utils/seriestables/test_series.py |
diffstat | 2 files changed, 97 insertions(+), 6 deletions(-) [+] |
line wrap: on
line diff
--- a/utils/seriestables/series.py Thu Mar 11 11:08:42 2010 -0500 +++ b/utils/seriestables/series.py Thu Mar 11 11:52:43 2010 -0500 @@ -18,7 +18,7 @@ def _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock, pos=0): toexec = "" - + if store_timestamp: toexec += "\ttimestamp = tables.Time32Col(pos="+str(pos)+")\n" pos += 1 @@ -210,8 +210,11 @@ raise NotImplementedError def _timestamp_cpuclock(self, newrow): - newrow["timestamp"] = time.time() - newrow["cpuclock"] = time.clock() + if self.store_timestamp: + newrow["timestamp"] = time.time() + + if self.store_cpuclock: + newrow["cpuclock"] = time.clock() class DummySeries(): """ @@ -267,7 +270,9 @@ def _create_table(self): table_description = _get_description_with_n_ints_n_floats( \ - self.index_names, (self.error_name,)) + self.index_names, (self.error_name,), + store_timestamp=self.store_timestamp, + store_cpuclock=self.store_cpuclock) self._table = self.hdf5_file.createTable(self.hdf5_group, self.table_name, @@ -540,7 +545,8 @@ ''' def __init__(self, arrays_names, new_group_name, hdf5_file, - base_group='/', index_names=('epoch',), title=""): + base_group='/', index_names=('epoch',), title="", + store_timestamp=True, store_cpuclock=True): """ For other parameters, see Series.__init__ @@ -560,6 +566,12 @@ title : str Here the title is attached to the new group, not a table. + + store_timestamp : bool + Here timestamp and cpuclock are stored in *each* table + + store_cpuclock : bool + Here timestamp and cpuclock are stored in *each* table """ # most other checks done when calling BasicStatisticsSeries @@ -584,7 +596,9 @@ hdf5_file=hdf5_file, index_names=index_names, stats_functions=stats_functions, - hdf5_group=new_group._v_pathname)) + hdf5_group=new_group._v_pathname, + store_timestamp=store_timestamp, + store_cpuclock=store_cpuclock)) SeriesArrayWrapper.__init__(self, base_series_list)
--- a/utils/seriestables/test_series.py Thu Mar 11 11:08:42 2010 -0500 +++ b/utils/seriestables/test_series.py Thu Mar 11 11:52:43 2010 -0500 @@ -65,6 +65,52 @@ assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) +def test_ErrorSeries_notimestamp(h5f=None): + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = tables.openFile(h5f_path, "w") + + validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error", + hdf5_file=h5f, index_names=('epoch','minibatch'), + title="Validation error indexed by epoch and minibatch", + store_timestamp=False) + + # (1,1), (1,2) etc. are (epoch, minibatch) index + validation_error.append((1,1), 32.0) + + h5f.close() + + h5f = tables.openFile(h5f_path, "r") + + table = h5f.getNode('/', 'validation_error') + + assert compare_lists(table.cols.epoch[:], [1]) + assert not ("timestamp" in dir(table.cols)) + assert "cpuclock" in dir(table.cols) + +def test_ErrorSeries_nocpuclock(h5f=None): + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = tables.openFile(h5f_path, "w") + + validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error", + hdf5_file=h5f, index_names=('epoch','minibatch'), + title="Validation error indexed by epoch and minibatch", + store_cpuclock=False) + + # (1,1), (1,2) etc. are (epoch, minibatch) index + validation_error.append((1,1), 32.0) + + h5f.close() + + h5f = tables.openFile(h5f_path, "r") + + table = h5f.getNode('/', 'validation_error') + + assert compare_lists(table.cols.epoch[:], [1]) + assert not ("cpuclock" in dir(table.cols)) + assert "timestamp" in dir(table.cols) + def test_AccumulatorSeriesWrapper_common_case(h5f=None): if not h5f: h5f_path = tempfile.NamedTemporaryFile().name @@ -153,6 +199,37 @@ assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 +def test_SharedParamsStatisticsWrapper_notimestamp(h5f=None): + import numpy.random + + if not h5f: + h5f_path = tempfile.NamedTemporaryFile().name + h5f = tables.openFile(h5f_path, "w") + + stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", + arrays_names=('b1','b2','b3'), hdf5_file=h5f, + index_names=('epoch','minibatch'), + store_timestamp=False) + + b1 = DD({'value':numpy.random.rand(5)}) + b2 = DD({'value':numpy.random.rand(5)}) + b3 = DD({'value':numpy.random.rand(5)}) + stats.append((1,1), [b1,b2,b3]) + + h5f.close() + + h5f = tables.openFile(h5f_path, "r") + + b1_table = h5f.getNode('/params', 'b1') + b3_table = h5f.getNode('/params', 'b3') + + assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 + assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3 + assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 + assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 + + assert not ('timestamp' in dir(b1_table.cols)) + def test_get_desc(): h5f_path = tempfile.NamedTemporaryFile().name h5f = tables.openFile(h5f_path, "w")