changeset 225:eb78a695ad7a

Merge
author fsavard
date Thu, 11 Mar 2010 16:29:39 -0500
parents 0515a8901c6a (diff) 8547b0cbe4ff (current diff)
children bfe20d63f88c
files
diffstat 2 files changed, 97 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/utils/seriestables/series.py	Thu Mar 11 14:42:54 2010 -0500
+++ b/utils/seriestables/series.py	Thu Mar 11 16:29:39 2010 -0500
@@ -18,7 +18,7 @@
 
 def _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock, pos=0):
     toexec = ""
-    
+
     if store_timestamp:
         toexec += "\ttimestamp = tables.Time32Col(pos="+str(pos)+")\n"
         pos += 1
@@ -210,8 +210,11 @@
         raise NotImplementedError
 
     def _timestamp_cpuclock(self, newrow):
-        newrow["timestamp"] = time.time()
-        newrow["cpuclock"] = time.clock()
+        if self.store_timestamp:
+            newrow["timestamp"] = time.time()
+
+        if self.store_cpuclock:
+            newrow["cpuclock"] = time.clock()
 
 class DummySeries():
     """
@@ -267,7 +270,9 @@
 
     def _create_table(self):
        table_description = _get_description_with_n_ints_n_floats( \
-                                  self.index_names, (self.error_name,))
+                                  self.index_names, (self.error_name,),
+                                  store_timestamp=self.store_timestamp,
+                                  store_cpuclock=self.store_cpuclock)
 
        self._table = self.hdf5_file.createTable(self.hdf5_group,
                             self.table_name, 
@@ -540,7 +545,8 @@
     '''
 
     def __init__(self, arrays_names, new_group_name, hdf5_file,
-                    base_group='/', index_names=('epoch',), title=""):
+                    base_group='/', index_names=('epoch',), title="",
+                    store_timestamp=True, store_cpuclock=True):
         """
         For other parameters, see Series.__init__
 
@@ -560,6 +566,12 @@
 
         title : str
             Here the title is attached to the new group, not a table.
+
+        store_timestamp : bool
+            Here timestamp and cpuclock are stored in *each* table
+
+        store_cpuclock : bool
+            Here timestamp and cpuclock are stored in *each* table
         """
 
         # most other checks done when calling BasicStatisticsSeries
@@ -584,7 +596,9 @@
                                 hdf5_file=hdf5_file,
                                 index_names=index_names,
                                 stats_functions=stats_functions,
-                                hdf5_group=new_group._v_pathname))
+                                hdf5_group=new_group._v_pathname,
+                                store_timestamp=store_timestamp,
+                                store_cpuclock=store_cpuclock))
 
         SeriesArrayWrapper.__init__(self, base_series_list)
 
--- a/utils/seriestables/test_series.py	Thu Mar 11 14:42:54 2010 -0500
+++ b/utils/seriestables/test_series.py	Thu Mar 11 16:29:39 2010 -0500
@@ -65,6 +65,52 @@
     assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
     assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
 
+def test_ErrorSeries_notimestamp(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
+                                hdf5_file=h5f, index_names=('epoch','minibatch'),
+                                title="Validation error indexed by epoch and minibatch", 
+                                store_timestamp=False)
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) index
+    validation_error.append((1,1), 32.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+    
+    table = h5f.getNode('/', 'validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [1])
+    assert not ("timestamp" in dir(table.cols))
+    assert "cpuclock" in dir(table.cols)
+
+def test_ErrorSeries_nocpuclock(h5f=None):
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
+                                hdf5_file=h5f, index_names=('epoch','minibatch'),
+                                title="Validation error indexed by epoch and minibatch", 
+                                store_cpuclock=False)
+
+    # (1,1), (1,2) etc. are (epoch, minibatch) index
+    validation_error.append((1,1), 32.0)
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+    
+    table = h5f.getNode('/', 'validation_error')
+
+    assert compare_lists(table.cols.epoch[:], [1])
+    assert not ("cpuclock" in dir(table.cols))
+    assert "timestamp" in dir(table.cols)
+
 def test_AccumulatorSeriesWrapper_common_case(h5f=None):
     if not h5f:
         h5f_path = tempfile.NamedTemporaryFile().name
@@ -153,6 +199,37 @@
     assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
     assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3
 
+def test_SharedParamsStatisticsWrapper_notimestamp(h5f=None):
+    import numpy.random
+
+    if not h5f:
+        h5f_path = tempfile.NamedTemporaryFile().name
+        h5f = tables.openFile(h5f_path, "w")
+
+    stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
+                                arrays_names=('b1','b2','b3'), hdf5_file=h5f,
+                                index_names=('epoch','minibatch'),
+                                store_timestamp=False)
+
+    b1 = DD({'value':numpy.random.rand(5)})
+    b2 = DD({'value':numpy.random.rand(5)})
+    b3 = DD({'value':numpy.random.rand(5)})
+    stats.append((1,1), [b1,b2,b3])
+
+    h5f.close()
+
+    h5f = tables.openFile(h5f_path, "r")
+
+    b1_table = h5f.getNode('/params', 'b1')
+    b3_table = h5f.getNode('/params', 'b3')
+
+    assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3
+    assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3
+    assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
+    assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3
+
+    assert not ('timestamp' in dir(b1_table.cols))
+
 def test_get_desc():
     h5f_path = tempfile.NamedTemporaryFile().name
     h5f = tables.openFile(h5f_path, "w")