comparison utils/seriestables/test_series.py @ 220:e172ef73cdc5

Ajouté un paquet de type/value checks à SeriesTables, et finalisé les docstrings. Ajouté 3-4 tests. Légers refactorings ici et là sans conséquences externes.
author fsavard
date Thu, 11 Mar 2010 10:48:54 -0500
parents 4c137f16b013
children 0515a8901c6a
comparison
equal deleted inserted replaced
219:cde71d24f235 220:e172ef73cdc5
1 import tempfile 1 import tempfile
2
2 import numpy 3 import numpy
3 import numpy.random 4 import numpy.random
4 5
5 from jobman import DD 6 from jobman import DD
6 7
7 from tables import * 8 import tables
8 9
9 from series import * 10 from series import *
10 import series 11 import series
11 12
13 #################################################
14 # Utils
12 15
13 def compare_floats(f1,f2): 16 def compare_floats(f1,f2):
14 if f1-f2 < 1e-3: 17 if f1-f2 < 1e-3:
15 return True 18 return True
16 return False 19 return False
26 elif el1 != el2: 29 elif el1 != el2:
27 return False 30 return False
28 31
29 return True 32 return True
30 33
34 #################################################
35 # Basic Series class tests
36
37 def test_Series_types():
38 pass
39
40 #################################################
41 # ErrorSeries tests
42
31 def test_ErrorSeries_common_case(h5f=None): 43 def test_ErrorSeries_common_case(h5f=None):
32 if not h5f: 44 if not h5f:
33 h5f_path = tempfile.NamedTemporaryFile().name 45 h5f_path = tempfile.NamedTemporaryFile().name
34 h5f = openFile(h5f_path, "w") 46 h5f = tables.openFile(h5f_path, "w")
35 47
36 validation_error = ErrorSeries(error_name="validation_error", table_name="validation_error", 48 validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
37 hdf5_file=h5f, index_names=('epoch','minibatch'), 49 hdf5_file=h5f, index_names=('epoch','minibatch'),
38 title="Validation error indexed by epoch and minibatch") 50 title="Validation error indexed by epoch and minibatch")
39 51
40 # (1,1), (1,2) etc. are (epoch, minibatch) index 52 # (1,1), (1,2) etc. are (epoch, minibatch) index
41 validation_error.append((1,1), 32.0) 53 validation_error.append((1,1), 32.0)
43 validation_error.append((2,1), 28.0) 55 validation_error.append((2,1), 28.0)
44 validation_error.append((2,2), 26.0) 56 validation_error.append((2,2), 26.0)
45 57
46 h5f.close() 58 h5f.close()
47 59
48 h5f = openFile(h5f_path, "r") 60 h5f = tables.openFile(h5f_path, "r")
49 61
50 table = h5f.getNode('/', 'validation_error') 62 table = h5f.getNode('/', 'validation_error')
51 63
52 assert compare_lists(table.cols.epoch[:], [1,1,2,2]) 64 assert compare_lists(table.cols.epoch[:], [1,1,2,2])
53 assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) 65 assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
54 assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0]) 66 assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
55 67
56 def test_AccumulatorSeriesWrapper_common_case(h5f=None): 68 def test_AccumulatorSeriesWrapper_common_case(h5f=None):
57 if not h5f: 69 if not h5f:
58 h5f_path = tempfile.NamedTemporaryFile().name 70 h5f_path = tempfile.NamedTemporaryFile().name
59 h5f = openFile(h5f_path, "w") 71 h5f = tables.openFile(h5f_path, "w")
60 72
61 validation_error = ErrorSeries(error_name="accumulated_validation_error", 73 validation_error = ErrorSeries(error_name="accumulated_validation_error",
62 table_name="accumulated_validation_error", 74 table_name="accumulated_validation_error",
63 hdf5_file=h5f, 75 hdf5_file=h5f,
64 index_names=('epoch','minibatch'), 76 index_names=('epoch','minibatch'),
75 accumulator.append((3,1), 24.0) 87 accumulator.append((3,1), 24.0)
76 accumulator.append((3,2), 22.0) 88 accumulator.append((3,2), 22.0)
77 89
78 h5f.close() 90 h5f.close()
79 91
80 h5f = openFile(h5f_path, "r") 92 h5f = tables.openFile(h5f_path, "r")
81 93
82 table = h5f.getNode('/', 'accumulated_validation_error') 94 table = h5f.getNode('/', 'accumulated_validation_error')
83 95
84 assert compare_lists(table.cols.epoch[:], [2,3]) 96 assert compare_lists(table.cols.epoch[:], [2,3])
85 assert compare_lists(table.cols.minibatch[:], [1,2]) 97 assert compare_lists(table.cols.minibatch[:], [1,2])
86 assert compare_lists(table.cols.accumulated_validation_error[:], [90.0,72.0], floats=True) 98 assert compare_lists(table.cols.accumulated_validation_error[:], [90.0,72.0], floats=True)
87 99
88 def test_BasicStatisticsSeries_common_case(h5f=None): 100 def test_BasicStatisticsSeries_common_case(h5f=None):
89 if not h5f: 101 if not h5f:
90 h5f_path = tempfile.NamedTemporaryFile().name 102 h5f_path = tempfile.NamedTemporaryFile().name
91 h5f = openFile(h5f_path, "w") 103 h5f = tables.openFile(h5f_path, "w")
92 104
93 stats_series = BasicStatisticsSeries(table_name="b_vector_statistics", 105 stats_series = BasicStatisticsSeries(table_name="b_vector_statistics",
94 hdf5_file=h5f, index_names=('epoch','minibatch'), 106 hdf5_file=h5f, index_names=('epoch','minibatch'),
95 title="Basic statistics for b vector indexed by epoch and minibatch") 107 title="Basic statistics for b vector indexed by epoch and minibatch")
96 108
100 stats_series.append((2,1), [0.18, -0.38, -0.68]) 112 stats_series.append((2,1), [0.18, -0.38, -0.68])
101 stats_series.append((2,2), [0.15, 0.02, 1.9]) 113 stats_series.append((2,2), [0.15, 0.02, 1.9])
102 114
103 h5f.close() 115 h5f.close()
104 116
105 h5f = openFile(h5f_path, "r") 117 h5f = tables.openFile(h5f_path, "r")
106 118
107 table = h5f.getNode('/', 'b_vector_statistics') 119 table = h5f.getNode('/', 'b_vector_statistics')
108 120
109 assert compare_lists(table.cols.epoch[:], [1,1,2,2]) 121 assert compare_lists(table.cols.epoch[:], [1,1,2,2])
110 assert compare_lists(table.cols.minibatch[:], [1,2,1,2]) 122 assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
116 def test_SharedParamsStatisticsWrapper_commoncase(h5f=None): 128 def test_SharedParamsStatisticsWrapper_commoncase(h5f=None):
117 import numpy.random 129 import numpy.random
118 130
119 if not h5f: 131 if not h5f:
120 h5f_path = tempfile.NamedTemporaryFile().name 132 h5f_path = tempfile.NamedTemporaryFile().name
121 h5f = openFile(h5f_path, "w") 133 h5f = tables.openFile(h5f_path, "w")
122 134
123 stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/", 135 stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
124 arrays_names=('b1','b2','b3'), hdf5_file=h5f, 136 arrays_names=('b1','b2','b3'), hdf5_file=h5f,
125 index_names=('epoch','minibatch')) 137 index_names=('epoch','minibatch'))
126 138
129 b3 = DD({'value':numpy.random.rand(5)}) 141 b3 = DD({'value':numpy.random.rand(5)})
130 stats.append((1,1), [b1,b2,b3]) 142 stats.append((1,1), [b1,b2,b3])
131 143
132 h5f.close() 144 h5f.close()
133 145
134 h5f = openFile(h5f_path, "r") 146 h5f = tables.openFile(h5f_path, "r")
135 147
136 b1_table = h5f.getNode('/params', 'b1') 148 b1_table = h5f.getNode('/params', 'b1')
137 b3_table = h5f.getNode('/params', 'b3') 149 b3_table = h5f.getNode('/params', 'b3')
138 150
139 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3 151 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3
141 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3 153 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
142 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3 154 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3
143 155
144 def test_get_desc(): 156 def test_get_desc():
145 h5f_path = tempfile.NamedTemporaryFile().name 157 h5f_path = tempfile.NamedTemporaryFile().name
146 h5f = openFile(h5f_path, "w") 158 h5f = tables.openFile(h5f_path, "w")
147 159
148 desc = series._get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4")) 160 desc = series._get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4"))
149 161
150 mytable = h5f.createTable('/', 'mytable', desc) 162 mytable = h5f.createTable('/', 'mytable', desc)
151 163
161 assert False 173 assert False
162 except: 174 except:
163 assert True 175 assert True
164 176
165 assert True 177 assert True
178
179 def test_index_to_tuple_floaterror():
180 try:
181 series._index_to_tuple(5.1)
182 assert False
183 except TypeError:
184 assert True
185
186 def test_index_to_tuple_arrayok():
187 tpl = series._index_to_tuple([1,2,3])
188 assert type(tpl) == tuple and tpl[1] == 2 and tpl[2] == 3
189
190 def test_index_to_tuple_intbecomestuple():
191 tpl = series._index_to_tuple(32)
192
193 assert type(tpl) == tuple and tpl == (32,)
194
195 def test_index_to_tuple_longbecomestuple():
196 tpl = series._index_to_tuple(928374928374928L)
197
198 assert type(tpl) == tuple and tpl == (928374928374928L,)
166 199
167 if __name__ == '__main__': 200 if __name__ == '__main__':
168 import tempfile 201 import tempfile
169 test_get_desc() 202 test_get_desc()
170 test_ErrorSeries_common_case() 203 test_ErrorSeries_common_case()