Mercurial > ift6266
diff utils/seriestables/test_series.py @ 220:e172ef73cdc5
Ajouté un paquet de type/value checks à SeriesTables, et finalisé les docstrings. Ajouté 3-4 tests. Légers refactorings ici et là sans conséquences externes.
author | fsavard |
---|---|
date | Thu, 11 Mar 2010 10:48:54 -0500 |
parents | 4c137f16b013 |
children | 0515a8901c6a |
line wrap: on
line diff
--- a/utils/seriestables/test_series.py Wed Mar 10 20:14:20 2010 -0500 +++ b/utils/seriestables/test_series.py Thu Mar 11 10:48:54 2010 -0500 @@ -1,14 +1,17 @@ import tempfile + import numpy import numpy.random from jobman import DD -from tables import * +import tables from series import * import series +################################################# +# Utils def compare_floats(f1,f2): if f1-f2 < 1e-3: @@ -28,12 +31,21 @@ return True +################################################# +# Basic Series class tests + +def test_Series_types(): + pass + +################################################# +# ErrorSeries tests + def test_ErrorSeries_common_case(h5f=None): if not h5f: h5f_path = tempfile.NamedTemporaryFile().name - h5f = openFile(h5f_path, "w") + h5f = tables.openFile(h5f_path, "w") - validation_error = ErrorSeries(error_name="validation_error", table_name="validation_error", + validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error", hdf5_file=h5f, index_names=('epoch','minibatch'), title="Validation error indexed by epoch and minibatch") @@ -45,7 +57,7 @@ h5f.close() - h5f = openFile(h5f_path, "r") + h5f = tables.openFile(h5f_path, "r") table = h5f.getNode('/', 'validation_error') @@ -56,7 +68,7 @@ def test_AccumulatorSeriesWrapper_common_case(h5f=None): if not h5f: h5f_path = tempfile.NamedTemporaryFile().name - h5f = openFile(h5f_path, "w") + h5f = tables.openFile(h5f_path, "w") validation_error = ErrorSeries(error_name="accumulated_validation_error", table_name="accumulated_validation_error", @@ -77,7 +89,7 @@ h5f.close() - h5f = openFile(h5f_path, "r") + h5f = tables.openFile(h5f_path, "r") table = h5f.getNode('/', 'accumulated_validation_error') @@ -88,7 +100,7 @@ def test_BasicStatisticsSeries_common_case(h5f=None): if not h5f: h5f_path = tempfile.NamedTemporaryFile().name - h5f = openFile(h5f_path, "w") + h5f = tables.openFile(h5f_path, "w") stats_series = BasicStatisticsSeries(table_name="b_vector_statistics", hdf5_file=h5f, index_names=('epoch','minibatch'), @@ -102,7 +114,7 @@ h5f.close() - h5f = openFile(h5f_path, "r") + h5f = tables.openFile(h5f_path, "r") table = h5f.getNode('/', 'b_vector_statistics') @@ -118,7 +130,7 @@ if not h5f: h5f_path = tempfile.NamedTemporaryFile().name - h5f = openFile(h5f_path, "w") + h5f = tables.openFile(h5f_path, "w") stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", arrays_names=('b1','b2','b3'), hdf5_file=h5f, @@ -131,7 +143,7 @@ h5f.close() - h5f = openFile(h5f_path, "r") + h5f = tables.openFile(h5f_path, "r") b1_table = h5f.getNode('/params', 'b1') b3_table = h5f.getNode('/params', 'b3') @@ -143,7 +155,7 @@ def test_get_desc(): h5f_path = tempfile.NamedTemporaryFile().name - h5f = openFile(h5f_path, "w") + h5f = tables.openFile(h5f_path, "w") desc = series._get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4")) @@ -164,6 +176,27 @@ assert True +def test_index_to_tuple_floaterror(): + try: + series._index_to_tuple(5.1) + assert False + except TypeError: + assert True + +def test_index_to_tuple_arrayok(): + tpl = series._index_to_tuple([1,2,3]) + assert type(tpl) == tuple and tpl[1] == 2 and tpl[2] == 3 + +def test_index_to_tuple_intbecomestuple(): + tpl = series._index_to_tuple(32) + + assert type(tpl) == tuple and tpl == (32,) + +def test_index_to_tuple_longbecomestuple(): + tpl = series._index_to_tuple(928374928374928L) + + assert type(tpl) == tuple and tpl == (928374928374928L,) + if __name__ == '__main__': import tempfile test_get_desc()