Mercurial > ift6266
comparison utils/seriestables/test_series.py @ 217:de3aef84714a
merge, second try
author | Dumitru Erhan <dumitru.erhan@gmail.com> |
---|---|
date | Wed, 10 Mar 2010 17:08:50 -0500 |
parents | a96fa4de06d2 |
children | 4c137f16b013 |
comparison
equal
deleted
inserted
replaced
216:c89004f9cab2 | 217:de3aef84714a |
---|---|
1 import tempfile | |
2 import numpy | |
3 import numpy.random | |
4 | |
5 from jobman import DD | |
6 | |
7 from tables import * | |
8 | |
9 from series import * | |
10 | |
11 | |
def compare_floats(f1, f2, tol=1e-3):
    """Return True when f1 and f2 differ by less than ``tol`` in absolute value.

    The previous version tested the signed difference ``f1 - f2 < 1e-3``,
    which accepted any f1 smaller than f2 no matter how far apart they were
    (e.g. compare_floats(0.0, 5.0) was True).

    :param f1: first number.
    :param f2: second number.
    :param tol: strict upper bound on ``abs(f1 - f2)``; defaults to the
        original hard-coded 1e-3.
    :returns: bool.
    """
    return abs(f1 - f2) < tol
16 | |
def compare_lists(it1, it2, floats=False):
    """Return True when two sequences have the same length and matching items.

    :param it1: first sequence (must support ``len``).
    :param it2: second sequence (must support ``len``).
    :param floats: when True, items are matched with ``compare_floats``
        (tolerance-based); otherwise items must not compare ``!=``.
    :returns: bool.
    """
    # Guard: sequences of different lengths can never match.
    if len(it1) != len(it2):
        return False

    pairs = zip(it1, it2)
    if floats:
        return not any(not compare_floats(a, b) for a, b in pairs)
    return not any(a != b for a, b in pairs)
29 | |
def test_ErrorSeries_common_case(h5f=None):
    """Check that ErrorSeries writes one row per appended (index, error) pair.

    Four errors indexed by (epoch, minibatch) are appended. The file is then
    closed, reopened read-only, and the ``validation_error`` table's columns
    are compared with what was written.

    :param h5f: optional PyTables file opened for writing. When None, a file
        is created at a fresh temporary path. Either way this test closes the
        file before reading it back.
    """
    if h5f is None:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = openFile(h5f_path, "w")
    else:
        # Reopen the caller's file by its own path. Previously h5f_path was
        # only bound in the branch above, so passing h5f raised NameError.
        h5f_path = h5f.filename

    validation_error = ErrorSeries(error_name="validation_error", table_name="validation_error",
                                   hdf5_file=h5f, index_names=('epoch', 'minibatch'),
                                   title="Validation error indexed by epoch and minibatch")

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    validation_error.append((1, 1), 32.0)
    validation_error.append((1, 2), 30.0)
    validation_error.append((2, 1), 28.0)
    validation_error.append((2, 2), 26.0)

    h5f.close()

    h5f = openFile(h5f_path, "r")
    try:
        table = h5f.getNode('/', 'validation_error')

        assert compare_lists(table.cols.epoch[:], [1, 1, 2, 2])
        assert compare_lists(table.cols.minibatch[:], [1, 2, 1, 2])
        assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
    finally:
        # Release the read handle even when an assertion fails.
        h5f.close()
54 | |
def test_AccumulatorSeriesWrapper_common_case(h5f=None):
    """Check that AccumulatorSeriesWrapper reduces every 3 appends into one row.

    Six values are appended with reduce_every=3 and numpy.sum. The table is
    expected to hold two rows, 32+30+28=90 and 26+24+22=72. Each row is
    indexed by the (epoch, minibatch) of the append that completed its group.

    :param h5f: optional PyTables file opened for writing. When None, a file
        is created at a fresh temporary path. Either way this test closes the
        file before reading it back.
    """
    if h5f is None:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = openFile(h5f_path, "w")
    else:
        # Previously h5f_path was unbound here, so passing h5f raised NameError.
        h5f_path = h5f.filename

    validation_error = ErrorSeries(error_name="accumulated_validation_error",
                                   table_name="accumulated_validation_error",
                                   hdf5_file=h5f,
                                   index_names=('epoch', 'minibatch'),
                                   title="Validation error, summed every 3 minibatches, indexed by epoch and minibatch")

    accumulator = AccumulatorSeriesWrapper(base_series=validation_error,
                                           reduce_every=3, reduce_function=numpy.sum)

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    accumulator.append((1, 1), 32.0)
    accumulator.append((1, 2), 30.0)
    accumulator.append((2, 1), 28.0)
    accumulator.append((2, 2), 26.0)
    accumulator.append((3, 1), 24.0)
    accumulator.append((3, 2), 22.0)

    h5f.close()

    h5f = openFile(h5f_path, "r")
    try:
        table = h5f.getNode('/', 'accumulated_validation_error')

        assert compare_lists(table.cols.epoch[:], [2, 3])
        assert compare_lists(table.cols.minibatch[:], [1, 2])
        assert compare_lists(table.cols.accumulated_validation_error[:], [90.0, 72.0], floats=True)
    finally:
        # Release the read handle even when an assertion fails.
        h5f.close()
86 | |
def test_BasicStatisticsSeries_common_case(h5f=None):
    """Check the per-row statistics that BasicStatisticsSeries stores.

    Four 3-element vectors are appended. The mean/min/max/std columns are
    then compared with precomputed values. The expected std values match
    numpy.std with its default ddof=0.

    :param h5f: optional PyTables file opened for writing. When None, a file
        is created at a fresh temporary path. Either way this test closes the
        file before reading it back.
    """
    if h5f is None:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = openFile(h5f_path, "w")
    else:
        # Previously h5f_path was unbound here, so passing h5f raised NameError.
        h5f_path = h5f.filename

    stats_series = BasicStatisticsSeries(table_name="b_vector_statistics",
                                         hdf5_file=h5f, index_names=('epoch', 'minibatch'),
                                         title="Basic statistics for b vector indexed by epoch and minibatch")

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    stats_series.append((1, 1), [0.15, 0.20, 0.30])
    stats_series.append((1, 2), [-0.18, 0.30, 0.58])
    stats_series.append((2, 1), [0.18, -0.38, -0.68])
    stats_series.append((2, 2), [0.15, 0.02, 1.9])

    h5f.close()

    h5f = openFile(h5f_path, "r")
    try:
        table = h5f.getNode('/', 'b_vector_statistics')

        assert compare_lists(table.cols.epoch[:], [1, 1, 2, 2])
        assert compare_lists(table.cols.minibatch[:], [1, 2, 1, 2])
        assert compare_lists(table.cols.mean[:], [0.21666667, 0.23333333, -0.29333332, 0.69], floats=True)
        assert compare_lists(table.cols.min[:], [0.15000001, -0.18000001, -0.68000001, 0.02], floats=True)
        assert compare_lists(table.cols.max[:], [0.30, 0.58, 0.18, 1.9], floats=True)
        assert compare_lists(table.cols.std[:], [0.06236095, 0.31382939, 0.35640177, 0.85724366], floats=True)
    finally:
        # Release the read handle even when an assertion fails.
        h5f.close()
114 | |
def test_SharedParamsStatisticsWrapper_commoncase(h5f=None):
    """Check that SharedParamsStatisticsWrapper writes per-array statistics.

    Three random 5-element parameters are appended once. The stored mean and
    min of the /params/b1 and /params/b3 tables are compared with numpy's
    values.

    :param h5f: optional PyTables file opened for writing. When None, a file
        is created at a fresh temporary path. Either way this test closes the
        file before reading it back.
    """
    if h5f is None:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = openFile(h5f_path, "w")
    else:
        # Previously h5f_path was unbound here, so passing h5f raised NameError.
        h5f_path = h5f.filename

    stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
                                          arrays_names=('b1', 'b2', 'b3'), hdf5_file=h5f,
                                          index_names=('epoch', 'minibatch'))

    # DD({'value': ...}) stands in for a shared parameter exposing ``.value``.
    # Presumably that is what the wrapper reads; confirm against series.py.
    b1 = DD({'value': numpy.random.rand(5)})
    b2 = DD({'value': numpy.random.rand(5)})
    b3 = DD({'value': numpy.random.rand(5)})
    stats.append((1, 1), [b1, b2, b3])

    h5f.close()

    h5f = openFile(h5f_path, "r")
    try:
        for name, param in (('b1', b1), ('b3', b3)):
            table = h5f.getNode('/params', name)
            # abs() is required: a signed difference accepts any stored value
            # that is far *below* the expected one.
            assert abs(table.cols.mean[0] - numpy.mean(param.value)) < 1e-3
            assert abs(table.cols.min[0] - numpy.min(param.value)) < 1e-3
    finally:
        # Release the read handle even when an assertion fails.
        h5f.close()
142 | |
def test_get_desc():
    """Check get_description_with_n_ints_n_floats and that it does not leak.

    The returned description must create a table with the requested int
    columns (col1, col2) and float columns (col3, col4). The helper class
    LocalDescription must not be reachable from this module's namespace.
    """
    h5f_path = tempfile.NamedTemporaryFile().name
    h5f = openFile(h5f_path, "w")
    try:
        desc = get_description_with_n_ints_n_floats(("col1", "col2"), ("col3", "col4"))

        mytable = h5f.createTable('/', 'mytable', desc)

        # just make sure the columns are there... otherwise this will throw an exception
        mytable.cols.col1
        mytable.cols.col2
        mytable.cols.col3
        mytable.cols.col4
    finally:
        h5f.close()

    # LocalDescription must stay local to get_description_with_n_ints_n_floats,
    # so it must not be visible here, even through the star imports. The old
    # bare ``except:`` also caught the AssertionError raised by ``assert False``,
    # which meant this check could never fail.
    try:
        LocalDescription
    except NameError:
        pass
    else:
        raise AssertionError("LocalDescription should be local to "
                             "get_description_with_n_ints_n_floats")
165 | |
if __name__ == '__main__':
    # Run every test in a fixed order; the first exception stops the run.
    for case in (test_get_desc,
                 test_ErrorSeries_common_case,
                 test_BasicStatisticsSeries_common_case,
                 test_AccumulatorSeriesWrapper_common_case,
                 test_SharedParamsStatisticsWrapper_commoncase):
        case()
173 |