Mercurial > ift6266
diff utils/seriestables/test_series.py @ 223:8547b0cbe4ff
Branch merge
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Thu, 11 Mar 2010 14:42:54 -0500 |
parents | e172ef73cdc5 |
children | 0515a8901c6a |
line wrap: on
line diff
--- a/utils/seriestables/test_series.py Thu Mar 11 14:41:31 2010 -0500 +++ b/utils/seriestables/test_series.py Thu Mar 11 14:42:54 2010 -0500 @@ -1,13 +1,17 @@ import tempfile + import numpy import numpy.random from jobman import DD -from tables import * +import tables from series import * +import series +################################################# +# Utils def compare_floats(f1,f2): if f1-f2 < 1e-3: @@ -27,12 +31,21 @@ return True +################################################# +# Basic Series class tests + +def test_Series_types(): + pass + +################################################# +# ErrorSeries tests + def test_ErrorSeries_common_case(h5f=None): if not h5f: h5f_path = tempfile.NamedTemporaryFile().name - h5f = openFile(h5f_path, "w") + h5f = tables.openFile(h5f_path, "w") - validation_error = ErrorSeries(error_name="validation_error", table_name="validation_error", + validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error", hdf5_file=h5f, index_names=('epoch','minibatch'), title="Validation error indexed by epoch and minibatch") @@ -44,7 +57,7 @@ h5f.close() - h5f = openFile(h5f_path, "r") + h5f = tables.openFile(h5f_path, "r") table = h5f.getNode('/', 'validation_error') @@ -55,7 +68,7 @@ def test_AccumulatorSeriesWrapper_common_case(h5f=None): if not h5f: h5f_path = tempfile.NamedTemporaryFile().name - h5f = openFile(h5f_path, "w") + h5f = tables.openFile(h5f_path, "w") validation_error = ErrorSeries(error_name="accumulated_validation_error", table_name="accumulated_validation_error", @@ -76,7 +89,7 @@ h5f.close() - h5f = openFile(h5f_path, "r") + h5f = tables.openFile(h5f_path, "r") table = h5f.getNode('/', 'accumulated_validation_error') @@ -87,7 +100,7 @@ def test_BasicStatisticsSeries_common_case(h5f=None): if not h5f: h5f_path = tempfile.NamedTemporaryFile().name - h5f = openFile(h5f_path, "w") + h5f = tables.openFile(h5f_path, "w") stats_series = BasicStatisticsSeries(table_name="b_vector_statistics", hdf5_file=h5f, index_names=('epoch','minibatch'), @@ -101,7 +114,7 @@ h5f.close() - h5f = openFile(h5f_path, "r") + h5f = tables.openFile(h5f_path, "r") table = h5f.getNode('/', 'b_vector_statistics') @@ -117,7 +130,7 @@ if not h5f: h5f_path = tempfile.NamedTemporaryFile().name - h5f = openFile(h5f_path, "w") + h5f = tables.openFile(h5f_path, "w") stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", arrays_names=('b1','b2','b3'), hdf5_file=h5f, @@ -130,7 +143,7 @@ h5f.close() - h5f = openFile(h5f_path, "r") + h5f = tables.openFile(h5f_path, "r") b1_table = h5f.getNode('/params', 'b1') b3_table = h5f.getNode('/params', 'b3') @@ -142,9 +155,9 @@ def test_get_desc(): h5f_path = tempfile.NamedTemporaryFile().name - h5f = openFile(h5f_path, "w") + h5f = tables.openFile(h5f_path, "w") - desc = get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4")) + desc = series._get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4")) mytable = h5f.createTable('/', 'mytable', desc) @@ -163,6 +176,27 @@ assert True +def test_index_to_tuple_floaterror(): + try: + series._index_to_tuple(5.1) + assert False + except TypeError: + assert True + +def test_index_to_tuple_arrayok(): + tpl = series._index_to_tuple([1,2,3]) + assert type(tpl) == tuple and tpl[1] == 2 and tpl[2] == 3 + +def test_index_to_tuple_intbecomestuple(): + tpl = series._index_to_tuple(32) + + assert type(tpl) == tuple and tpl == (32,) + +def test_index_to_tuple_longbecomestuple(): + tpl = series._index_to_tuple(928374928374928L) + + assert type(tpl) == tuple and tpl == (928374928374928L,) + if __name__ == '__main__': import tempfile test_get_desc()