comparison utils/seriestables/test_series.py @ 213:a96fa4de06d2

Renommé mon module de séries
author fsavard
date Wed, 10 Mar 2010 16:52:22 -0500
parents utils/tables_series/test_series.py@dc0d77c8a878
children 4c137f16b013
comparison
equal deleted inserted replaced
212:e390b0454515 213:a96fa4de06d2
1 import tempfile
2 import numpy
3 import numpy.random
4
5 from jobman import DD
6
7 from tables import *
8
9 from series import *
10
11
def compare_floats(f1, f2, tol=1e-3):
    """Return True if ``f1`` and ``f2`` are within ``tol`` of each other.

    :param f1: first number
    :param f2: second number
    :param tol: absolute tolerance (defaults to 1e-3, the value the tests
        in this module were written against)
    :returns: ``True`` if ``|f1 - f2| < tol``, else ``False``
    """
    # abs() makes the check symmetric: without it, any f1 far *below* f2
    # would have been reported as equal.
    return bool(abs(f1 - f2) < tol)

def compare_lists(it1, it2, floats=False):
    """Check two sized sequences for element-wise equality.

    Sequences of different lengths never match. With ``floats`` set, each
    pair of elements is compared through ``compare_floats`` (tolerance
    based); otherwise pairs are compared with ``!=``.

    :param it1: first sequence (must support ``len``)
    :param it2: second sequence (must support ``len``)
    :param floats: use tolerance comparison instead of exact comparison
    :returns: ``True`` when every pair matches, ``False`` otherwise
    """
    if len(it2) != len(it1) if False else len(it1) != len(it2):
        return False

    pairs = zip(it1, it2)
    if floats:
        return all(compare_floats(left, right) for left, right in pairs)
    return not any(left != right for left, right in pairs)

def test_ErrorSeries_common_case(h5f=None):
    """Append four (epoch, minibatch) -> error rows via ErrorSeries, then read them back.

    :param h5f: optional PyTables file already opened for writing. When
        omitted, a file is created at a fresh temporary path. Either way the
        file is closed here and reopened read-only to check the stored
        columns.
    """
    if h5f is None:
        # NOTE: only the name is kept; the NamedTemporaryFile object itself
        # is discarded right away.
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = openFile(h5f_path, "w")
    else:
        # A caller-supplied file must still be reopened below, so recover its path.
        h5f_path = h5f.filename

    validation_error = ErrorSeries(error_name="validation_error", table_name="validation_error",
                                   hdf5_file=h5f, index_names=('epoch', 'minibatch'),
                                   title="Validation error indexed by epoch and minibatch")

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    validation_error.append((1, 1), 32.0)
    validation_error.append((1, 2), 30.0)
    validation_error.append((2, 1), 28.0)
    validation_error.append((2, 2), 26.0)

    h5f.close()

    h5f = openFile(h5f_path, "r")
    try:
        table = h5f.getNode('/', 'validation_error')

        assert compare_lists(table.cols.epoch[:], [1, 1, 2, 2])
        assert compare_lists(table.cols.minibatch[:], [1, 2, 1, 2])
        assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
    finally:
        # Release the read handle even when an assertion fails.
        h5f.close()

def test_AccumulatorSeriesWrapper_common_case(h5f=None):
    """Check that AccumulatorSeriesWrapper sums every 3 appended errors into one row.

    Six values are appended; with ``reduce_every=3`` and ``numpy.sum`` the
    test expects two stored rows (90.0 and 72.0), indexed by the index of the
    last append in each group.

    :param h5f: optional PyTables file already opened for writing. When
        omitted, a file is created at a fresh temporary path. Either way the
        file is closed here and reopened read-only for verification.
    """
    if h5f is None:
        # NOTE: only the name is kept; the NamedTemporaryFile object itself
        # is discarded right away.
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = openFile(h5f_path, "w")
    else:
        # A caller-supplied file must still be reopened below, so recover its path.
        h5f_path = h5f.filename

    validation_error = ErrorSeries(error_name="accumulated_validation_error",
                                   table_name="accumulated_validation_error",
                                   hdf5_file=h5f,
                                   index_names=('epoch', 'minibatch'),
                                   title="Validation error, summed every 3 minibatches, indexed by epoch and minibatch")

    accumulator = AccumulatorSeriesWrapper(base_series=validation_error,
                                           reduce_every=3, reduce_function=numpy.sum)

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    accumulator.append((1, 1), 32.0)
    accumulator.append((1, 2), 30.0)
    accumulator.append((2, 1), 28.0)
    accumulator.append((2, 2), 26.0)
    accumulator.append((3, 1), 24.0)
    accumulator.append((3, 2), 22.0)

    h5f.close()

    h5f = openFile(h5f_path, "r")
    try:
        table = h5f.getNode('/', 'accumulated_validation_error')

        assert compare_lists(table.cols.epoch[:], [2, 3])
        assert compare_lists(table.cols.minibatch[:], [1, 2])
        assert compare_lists(table.cols.accumulated_validation_error[:], [90.0, 72.0], floats=True)
    finally:
        # Release the read handle even when an assertion fails.
        h5f.close()

77 h5f.close()
78
79 h5f = openFile(h5f_path, "r")
80
81 table = h5f.getNode('/', 'accumulated_validation_error')
82
83 assert compare_lists(table.cols.epoch[:], [2,3])
84 assert compare_lists(table.cols.minibatch[:], [1,2])
85 assert compare_lists(table.cols.accumulated_validation_error[:], [90.0,72.0], floats=True)
86
87 def test_BasicStatisticsSeries_common_case(h5f=None):
88 if not h5f:
89 h5f_path = tempfile.NamedTemporaryFile().name
90 h5f = openFile(h5f_path, "w")
91
92 stats_series = BasicStatisticsSeries(table_name="b_vector_statistics",
93 hdf5_file=h5f, index_names=('epoch','minibatch'),
94 title="Basic statistics for b vector indexed by epoch and minibatch")
95
96 # (1,1), (1,2) etc. are (epoch, minibatch) index
97 stats_series.append((1,1), [0.15, 0.20, 0.30])
98 stats_series.append((1,2), [-0.18, 0.30, 0.58])
99 stats_series.append((2,1), [0.18, -0.38, -0.68])
100 stats_series.append((2,2), [0.15, 0.02, 1.9])
101
102 h5f.close()
103
104 h5f = openFile(h5f_path, "r")
105
106 table = h5f.getNode('/', 'b_vector_statistics')
107
108 assert compare_lists(table.cols.epoch[:], [1,1,2,2])
109 assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
110 assert compare_lists(table.cols.mean[:], [0.21666667, 0.23333333, -0.29333332, 0.69], floats=True)
111 assert compare_lists(table.cols.min[:], [0.15000001, -0.18000001, -0.68000001, 0.02], floats=True)
112 assert compare_lists(table.cols.max[:], [0.30, 0.58, 0.18, 1.9], floats=True)
113 assert compare_lists(table.cols.std[:], [0.06236095, 0.31382939, 0.35640177, 0.85724366], floats=True)
114
115 def test_SharedParamsStatisticsWrapper_commoncase(h5f=None):
116 import numpy.random
117
118 if not h5f:
119 h5f_path = tempfile.NamedTemporaryFile().name
120 h5f = openFile(h5f_path, "w")
121
122 stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
123 arrays_names=('b1','b2','b3'), hdf5_file=h5f,
124 index_names=('epoch','minibatch'))
125
126 b1 = DD({'value':numpy.random.rand(5)})
127 b2 = DD({'value':numpy.random.rand(5)})
128 b3 = DD({'value':numpy.random.rand(5)})
129 stats.append((1,1), [b1,b2,b3])
130
131 h5f.close()
132
133 h5f = openFile(h5f_path, "r")
134
135 b1_table = h5f.getNode('/params', 'b1')
136 b3_table = h5f.getNode('/params', 'b3')
137
138 assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3
139 assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3
140 assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
141 assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3
142
143 def test_get_desc():
144 h5f_path = tempfile.NamedTemporaryFile().name
145 h5f = openFile(h5f_path, "w")
146
147 desc = get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4"))
148
149 mytable = h5f.createTable('/', 'mytable', desc)
150
151 # just make sure the columns are there... otherwise this will throw an exception
152 mytable.cols.col1
153 mytable.cols.col2
154 mytable.cols.col3
155 mytable.cols.col4
156
157 try:
158 # this should fail... LocalDescription must be local to get_desc_etc
159 test = LocalDescription
160 assert False
161 except:
162 assert True
163
164 assert True
165
166 if __name__ == '__main__':
167 import tempfile
168 test_get_desc()
169 test_ErrorSeries_common_case()
170 test_BasicStatisticsSeries_common_case()
171 test_AccumulatorSeriesWrapper_common_case()
172 test_SharedParamsStatisticsWrapper_commoncase()
173