diff utils/seriestables/test_series.py @ 220:e172ef73cdc5

Ajouté un paquet de type/value checks à SeriesTables, et finalisé les docstrings. Ajouté 3-4 tests. Légers refactorings ici et là sans conséquences externes.
author fsavard
date Thu, 11 Mar 2010 10:48:54 -0500
parents 4c137f16b013
children 0515a8901c6a
line wrap: on
line diff
--- a/utils/seriestables/test_series.py	Wed Mar 10 20:14:20 2010 -0500
+++ b/utils/seriestables/test_series.py	Thu Mar 11 10:48:54 2010 -0500
@@ -1,14 +1,17 @@
 import tempfile
+
 import numpy
 import numpy.random
 
 from jobman import DD
 
-from tables import *
+import tables
 
 from series import *
 import series
 
+#################################################
+# Utils
 
 def compare_floats(f1,f2):
     if f1-f2 < 1e-3:
@@ -28,12 +31,21 @@
 
     return True
 
+#################################################
+# Basic Series class tests
+
+def test_Series_types():
+    pass
+
+#################################################
+# ErrorSeries tests
+
 def test_ErrorSeries_common_case(h5f=None):
     if not h5f:
         h5f_path = tempfile.NamedTemporaryFile().name
-        h5f = openFile(h5f_path, "w")
+        h5f = tables.openFile(h5f_path, "w")
 
-    validation_error = ErrorSeries(error_name="validation_error", table_name="validation_error",
+    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
                                 hdf5_file=h5f, index_names=('epoch','minibatch'),
                                 title="Validation error indexed by epoch and minibatch")
 
@@ -45,7 +57,7 @@
 
     h5f.close()
 
-    h5f = openFile(h5f_path, "r")
+    h5f = tables.openFile(h5f_path, "r")
     
     table = h5f.getNode('/', 'validation_error')
 
@@ -56,7 +68,7 @@
 def test_AccumulatorSeriesWrapper_common_case(h5f=None):
     if not h5f:
         h5f_path = tempfile.NamedTemporaryFile().name
-        h5f = openFile(h5f_path, "w")
+        h5f = tables.openFile(h5f_path, "w")
 
     validation_error = ErrorSeries(error_name="accumulated_validation_error",
                                 table_name="accumulated_validation_error",
@@ -77,7 +89,7 @@
 
     h5f.close()
 
-    h5f = openFile(h5f_path, "r")
+    h5f = tables.openFile(h5f_path, "r")
     
     table = h5f.getNode('/', 'accumulated_validation_error')
 
@@ -88,7 +100,7 @@
 def test_BasicStatisticsSeries_common_case(h5f=None):
     if not h5f:
         h5f_path = tempfile.NamedTemporaryFile().name
-        h5f = openFile(h5f_path, "w")
+        h5f = tables.openFile(h5f_path, "w")
 
     stats_series = BasicStatisticsSeries(table_name="b_vector_statistics",
                                 hdf5_file=h5f, index_names=('epoch','minibatch'),
@@ -102,7 +114,7 @@
 
     h5f.close()
 
-    h5f = openFile(h5f_path, "r")
+    h5f = tables.openFile(h5f_path, "r")
     
     table = h5f.getNode('/', 'b_vector_statistics')
 
@@ -118,7 +130,7 @@
 
     if not h5f:
         h5f_path = tempfile.NamedTemporaryFile().name
-        h5f = openFile(h5f_path, "w")
+        h5f = tables.openFile(h5f_path, "w")
 
     stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
                                 arrays_names=('b1','b2','b3'), hdf5_file=h5f,
@@ -131,7 +143,7 @@
 
     h5f.close()
 
-    h5f = openFile(h5f_path, "r")
+    h5f = tables.openFile(h5f_path, "r")
 
     b1_table = h5f.getNode('/params', 'b1')
     b3_table = h5f.getNode('/params', 'b3')
@@ -143,7 +155,7 @@
 
 def test_get_desc():
     h5f_path = tempfile.NamedTemporaryFile().name
-    h5f = openFile(h5f_path, "w")
+    h5f = tables.openFile(h5f_path, "w")
 
     desc = series._get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4"))
 
@@ -164,6 +176,27 @@
 
     assert True
 
+def test_index_to_tuple_floaterror():
+    try:
+        series._index_to_tuple(5.1)
+        assert False
+    except TypeError:
+        assert True
+
+def test_index_to_tuple_arrayok():
+    tpl = series._index_to_tuple([1,2,3])
+    assert type(tpl) == tuple and tpl[1] == 2 and tpl[2] == 3
+
+def test_index_to_tuple_intbecomestuple():
+    tpl = series._index_to_tuple(32)
+
+    assert type(tpl) == tuple and tpl == (32,)
+
+def test_index_to_tuple_longbecomestuple():
+    tpl = series._index_to_tuple(928374928374928L)
+
+    assert type(tpl) == tuple and tpl == (928374928374928L,)
+
 if __name__ == '__main__':
     import tempfile
     test_get_desc()