Mercurial > ift6266
comparison utils/seriestables/test_series.py @ 220:e172ef73cdc5
Ajout de nombreuses vérifications de type et de valeur (type/value checks) dans SeriesTables, et finalisation des docstrings. Ajout de 3-4 tests. Légers refactorings ici et là, sans conséquences externes.
author | fsavard |
---|---|
date | Thu, 11 Mar 2010 10:48:54 -0500 |
parents | 4c137f16b013 |
children | 0515a8901c6a |
comparison
equal
deleted
inserted
replaced
219:cde71d24f235 | 220:e172ef73cdc5 |
---|---|
1 import tempfile | 1 import tempfile |
2 | |
2 import numpy | 3 import numpy |
3 import numpy.random | 4 import numpy.random |
4 | 5 |
5 from jobman import DD | 6 from jobman import DD |
6 | 7 |
7 from tables import * | 8 import tables |
8 | 9 |
9 from series import * | 10 from series import * |
10 import series | 11 import series |
11 | 12 |
13 ################################################# | |
14 # Utils | |
12 | 15 |
13 def compare_floats(f1,f2): | 16 def compare_floats(f1,f2): |
14 if f1-f2 < 1e-3: | 17 if f1-f2 < 1e-3: |
15 return True | 18 return True |
16 return False | 19 return False |
26 elif el1 != el2: | 29 elif el1 != el2: |
27 return False | 30 return False |
28 | 31 |
29 return True | 32 return True |
30 | 33 |
34 ################################################# | |
35 # Basic Series class tests | |
36 | |
37 def test_Series_types(): | |
38 pass | |
39 | |
40 ################################################# | |
41 # ErrorSeries tests | |
42 | |
31 def test_ErrorSeries_common_case(h5f=None): | 43 def test_ErrorSeries_common_case(h5f=None): |
32 if not h5f: | 44 if not h5f: |
33 h5f_path = tempfile.NamedTemporaryFile().name | 45 h5f_path = tempfile.NamedTemporaryFile().name |
34 h5f = openFile(h5f_path, "w") | 46 h5f = tables.openFile(h5f_path, "w") |
35 | 47 |
36 validation_error = ErrorSeries(error_name="validation_error", table_name="validation_error", | 48 validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error", |
37 hdf5_file=h5f, index_names=('epoch','minibatch'), | 49 hdf5_file=h5f, index_names=('epoch','minibatch'), |
38 title="Validation error indexed by epoch and minibatch") | 50 title="Validation error indexed by epoch and minibatch") |
39 | 51 |
40 # (1,1), (1,2) etc. are (epoch, minibatch) index | 52 # (1,1), (1,2) etc. are (epoch, minibatch) index |
41 validation_error.append((1,1), 32.0) | 53 validation_error.append((1,1), 32.0) |
43 validation_error.append((2,1), 28.0) | 55 validation_error.append((2,1), 28.0) |
44 validation_error.append((2,2), 26.0) | 56 validation_error.append((2,2), 26.0) |
45 | 57 |
46 h5f.close() | 58 h5f.close() |
47 | 59 |
48 h5f = openFile(h5f_path, "r") | 60 h5f = tables.openFile(h5f_path, "r") |
49 | 61 |
50 table = h5f.getNode('/', 'validation_error') | 62 table = h5f.getNode('/', 'validation_error') |
51 | 63 |
52 assert compare_lists(table.cols.epoch[:], [1,1,2,2]) | 64 assert compare_lists(table.cols.epoch[:], [1,1,2,2]) |
53 assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) | 65 assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) |
54 assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) | 66 assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) |
55 | 67 |
56 def test_AccumulatorSeriesWrapper_common_case(h5f=None): | 68 def test_AccumulatorSeriesWrapper_common_case(h5f=None): |
57 if not h5f: | 69 if not h5f: |
58 h5f_path = tempfile.NamedTemporaryFile().name | 70 h5f_path = tempfile.NamedTemporaryFile().name |
59 h5f = openFile(h5f_path, "w") | 71 h5f = tables.openFile(h5f_path, "w") |
60 | 72 |
61 validation_error = ErrorSeries(error_name="accumulated_validation_error", | 73 validation_error = ErrorSeries(error_name="accumulated_validation_error", |
62 table_name="accumulated_validation_error", | 74 table_name="accumulated_validation_error", |
63 hdf5_file=h5f, | 75 hdf5_file=h5f, |
64 index_names=('epoch','minibatch'), | 76 index_names=('epoch','minibatch'), |
75 accumulator.append((3,1), 24.0) | 87 accumulator.append((3,1), 24.0) |
76 accumulator.append((3,2), 22.0) | 88 accumulator.append((3,2), 22.0) |
77 | 89 |
78 h5f.close() | 90 h5f.close() |
79 | 91 |
80 h5f = openFile(h5f_path, "r") | 92 h5f = tables.openFile(h5f_path, "r") |
81 | 93 |
82 table = h5f.getNode('/', 'accumulated_validation_error') | 94 table = h5f.getNode('/', 'accumulated_validation_error') |
83 | 95 |
84 assert compare_lists(table.cols.epoch[:], [2,3]) | 96 assert compare_lists(table.cols.epoch[:], [2,3]) |
85 assert compare_lists(table.cols.minibatch[:], [1,2]) | 97 assert compare_lists(table.cols.minibatch[:], [1,2]) |
86 assert compare_lists(table.cols.accumulated_validation_error[:], [90.0,72.0], floats=True) | 98 assert compare_lists(table.cols.accumulated_validation_error[:], [90.0,72.0], floats=True) |
87 | 99 |
88 def test_BasicStatisticsSeries_common_case(h5f=None): | 100 def test_BasicStatisticsSeries_common_case(h5f=None): |
89 if not h5f: | 101 if not h5f: |
90 h5f_path = tempfile.NamedTemporaryFile().name | 102 h5f_path = tempfile.NamedTemporaryFile().name |
91 h5f = openFile(h5f_path, "w") | 103 h5f = tables.openFile(h5f_path, "w") |
92 | 104 |
93 stats_series = BasicStatisticsSeries(table_name="b_vector_statistics", | 105 stats_series = BasicStatisticsSeries(table_name="b_vector_statistics", |
94 hdf5_file=h5f, index_names=('epoch','minibatch'), | 106 hdf5_file=h5f, index_names=('epoch','minibatch'), |
95 title="Basic statistics for b vector indexed by epoch and minibatch") | 107 title="Basic statistics for b vector indexed by epoch and minibatch") |
96 | 108 |
100 stats_series.append((2,1), [0.18, -0.38, -0.68]) | 112 stats_series.append((2,1), [0.18, -0.38, -0.68]) |
101 stats_series.append((2,2), [0.15, 0.02, 1.9]) | 113 stats_series.append((2,2), [0.15, 0.02, 1.9]) |
102 | 114 |
103 h5f.close() | 115 h5f.close() |
104 | 116 |
105 h5f = openFile(h5f_path, "r") | 117 h5f = tables.openFile(h5f_path, "r") |
106 | 118 |
107 table = h5f.getNode('/', 'b_vector_statistics') | 119 table = h5f.getNode('/', 'b_vector_statistics') |
108 | 120 |
109 assert compare_lists(table.cols.epoch[:], [1,1,2,2]) | 121 assert compare_lists(table.cols.epoch[:], [1,1,2,2]) |
110 assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) | 122 assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) |
116 def test_SharedParamsStatisticsWrapper_commoncase(h5f=None): | 128 def test_SharedParamsStatisticsWrapper_commoncase(h5f=None): |
117 import numpy.random | 129 import numpy.random |
118 | 130 |
119 if not h5f: | 131 if not h5f: |
120 h5f_path = tempfile.NamedTemporaryFile().name | 132 h5f_path = tempfile.NamedTemporaryFile().name |
121 h5f = openFile(h5f_path, "w") | 133 h5f = tables.openFile(h5f_path, "w") |
122 | 134 |
123 stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", | 135 stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", |
124 arrays_names=('b1','b2','b3'), hdf5_file=h5f, | 136 arrays_names=('b1','b2','b3'), hdf5_file=h5f, |
125 index_names=('epoch','minibatch')) | 137 index_names=('epoch','minibatch')) |
126 | 138 |
129 b3 = DD({'value':numpy.random.rand(5)}) | 141 b3 = DD({'value':numpy.random.rand(5)}) |
130 stats.append((1,1), [b1,b2,b3]) | 142 stats.append((1,1), [b1,b2,b3]) |
131 | 143 |
132 h5f.close() | 144 h5f.close() |
133 | 145 |
134 h5f = openFile(h5f_path, "r") | 146 h5f = tables.openFile(h5f_path, "r") |
135 | 147 |
136 b1_table = h5f.getNode('/params', 'b1') | 148 b1_table = h5f.getNode('/params', 'b1') |
137 b3_table = h5f.getNode('/params', 'b3') | 149 b3_table = h5f.getNode('/params', 'b3') |
138 | 150 |
139 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 | 151 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 |
141 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 | 153 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 |
142 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 | 154 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 |
143 | 155 |
144 def test_get_desc(): | 156 def test_get_desc(): |
145 h5f_path = tempfile.NamedTemporaryFile().name | 157 h5f_path = tempfile.NamedTemporaryFile().name |
146 h5f = openFile(h5f_path, "w") | 158 h5f = tables.openFile(h5f_path, "w") |
147 | 159 |
148 desc = series._get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4")) | 160 desc = series._get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4")) |
149 | 161 |
150 mytable = h5f.createTable('/', 'mytable', desc) | 162 mytable = h5f.createTable('/', 'mytable', desc) |
151 | 163 |
161 assert False | 173 assert False |
162 except: | 174 except: |
163 assert True | 175 assert True |
164 | 176 |
165 assert True | 177 assert True |
178 | |
179 def test_index_to_tuple_floaterror(): | |
180 try: | |
181 series._index_to_tuple(5.1) | |
182 assert False | |
183 except TypeError: | |
184 assert True | |
185 | |
186 def test_index_to_tuple_arrayok(): | |
187 tpl = series._index_to_tuple([1,2,3]) | |
188 assert type(tpl) == tuple and tpl[1] == 2 and tpl[2] == 3 | |
189 | |
190 def test_index_to_tuple_intbecomestuple(): | |
191 tpl = series._index_to_tuple(32) | |
192 | |
193 assert type(tpl) == tuple and tpl == (32,) | |
194 | |
195 def test_index_to_tuple_longbecomestuple(): | |
196 tpl = series._index_to_tuple(928374928374928L) | |
197 | |
198 assert type(tpl) == tuple and tpl == (928374928374928L,) | |
166 | 199 |
167 if __name__ == '__main__': | 200 if __name__ == '__main__': |
168 import tempfile | 201 import tempfile |
169 test_get_desc() | 202 test_get_desc() |
170 test_ErrorSeries_common_case() | 203 test_ErrorSeries_common_case() |