comparison utils/tables_series/test_series.py @ 208:acb942530923

Completely rewrote my series module; it is now based on HDF5 and PyTables and lives in a separate directory called 'tables_series' to stay backward compatible with code that is already running. Minor (inconsequential) changes to stacked_dae.
author fsavard
date Fri, 05 Mar 2010 18:07:20 -0500
parents
children dc0d77c8a878
comparison
equal deleted inserted replaced
205:10a801240bfc 208:acb942530923
1 import tempfile
2 import numpy
3 import numpy.random
4 from tables import *
5
6 from series import *
7
8
def compare_floats(f1, f2, tolerance=1e-3):
    """Return True when f1 and f2 differ by less than ``tolerance``.

    The absolute difference is used, so the result does not depend on
    argument order.

    :param f1: first number (Python or numpy scalar).
    :param f2: second number.
    :param tolerance: strict upper bound on ``abs(f1 - f2)``; defaults to
        1e-3, the threshold this module has always used.
    :returns: a plain ``bool``.
    """
    # abs() matters: a one-sided check `f1 - f2 < tol` accepts any f1 that
    # is smaller than f2 (e.g. 0.0 vs 100.0).
    return bool(abs(f1 - f2) < tolerance)
13
def compare_lists(it1, it2, floats=False):
    """Check whether two sized sequences match element by element.

    :param it1: first sequence (anything supporting len() and iteration).
    :param it2: second sequence.
    :param floats: when true, elements are compared with compare_floats
        instead of ``!=``.
    :returns: False on a length mismatch or the first differing pair of
        elements, True otherwise.
    """
    if len(it2) != len(it1) if False else len(it1) != len(it2):
        return False

    pairs = zip(it1, it2)
    if floats:
        # all() stops at the first pair compare_floats rejects.
        return all(compare_floats(left, right) for left, right in pairs)
    # any() stops at the first pair that differs under `!=`.
    return not any(left != right for left, right in pairs)
26
def test_ErrorSeries_common_case(h5f=None):
    """Append four (epoch, minibatch) errors via ErrorSeries and read them back.

    :param h5f: optional PyTables file already open for writing. When
        omitted, a temporary file is created here and removed at the end.
        Either way the file is closed by this test and reopened read-only.
    """
    import os

    owns_file = h5f is None
    if owns_file:
        # mkstemp really creates the file; reusing the name of a deleted
        # NamedTemporaryFile leaves a window for another process to take it.
        fd, tmp_path = tempfile.mkstemp(suffix='.h5')
        os.close(fd)
        h5f = openFile(tmp_path, "w")
    # Needed to reopen the file below; it must also be set when the caller
    # supplies h5f, otherwise the reopen raises NameError.
    h5f_path = h5f.filename

    validation_error = ErrorSeries(error_name="validation_error",
                                   table_name="validation_error",
                                   hdf5_file=h5f,
                                   index_names=('epoch', 'minibatch'),
                                   title="Validation error indexed by epoch and minibatch")

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    validation_error.append((1, 1), 32.0)
    validation_error.append((1, 2), 30.0)
    validation_error.append((2, 1), 28.0)
    validation_error.append((2, 2), 26.0)

    # Close the writer before reopening the same file read-only.
    h5f.close()

    h5f = openFile(h5f_path, "r")
    try:
        table = h5f.getNode('/', 'validation_error')

        assert compare_lists(table.cols.epoch[:], [1, 1, 2, 2])
        assert compare_lists(table.cols.minibatch[:], [1, 2, 1, 2])
        assert compare_lists(table.cols.validation_error[:],
                             [32.0, 30.0, 28.0, 26.0])
    finally:
        h5f.close()
        if owns_file:
            os.remove(h5f_path)
51
def test_AccumulatorSeriesWrapper_common_case(h5f=None):
    """Sum every 3 appended errors via AccumulatorSeriesWrapper and read back.

    Six values are appended, so two reduced rows are expected. Each row is
    indexed by the append that triggered the reduction: (2,1) for the 3rd
    append and (3,2) for the 6th.

    :param h5f: optional PyTables file already open for writing. When
        omitted, a temporary file is created here and removed at the end.
        Either way the file is closed by this test and reopened read-only.
    """
    import os

    owns_file = h5f is None
    if owns_file:
        # mkstemp really creates the file; reusing the name of a deleted
        # NamedTemporaryFile leaves a window for another process to take it.
        fd, tmp_path = tempfile.mkstemp(suffix='.h5')
        os.close(fd)
        h5f = openFile(tmp_path, "w")
    # Needed to reopen the file below; it must also be set when the caller
    # supplies h5f, otherwise the reopen raises NameError.
    h5f_path = h5f.filename

    validation_error = ErrorSeries(error_name="accumulated_validation_error",
                                   table_name="accumulated_validation_error",
                                   hdf5_file=h5f,
                                   index_names=('epoch', 'minibatch'),
                                   title="Validation error, summed every 3 minibatches, indexed by epoch and minibatch")

    accumulator = AccumulatorSeriesWrapper(base_series=validation_error,
                                           reduce_every=3,
                                           reduce_function=numpy.sum)

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    accumulator.append((1, 1), 32.0)
    accumulator.append((1, 2), 30.0)
    accumulator.append((2, 1), 28.0)
    accumulator.append((2, 2), 26.0)
    accumulator.append((3, 1), 24.0)
    accumulator.append((3, 2), 22.0)

    # Close the writer before reopening the same file read-only.
    h5f.close()

    h5f = openFile(h5f_path, "r")
    try:
        table = h5f.getNode('/', 'accumulated_validation_error')

        assert compare_lists(table.cols.epoch[:], [2, 3])
        assert compare_lists(table.cols.minibatch[:], [1, 2])
        # 32+30+28 = 90 and 26+24+22 = 72
        assert compare_lists(table.cols.accumulated_validation_error[:],
                             [90.0, 72.0], floats=True)
    finally:
        h5f.close()
        if owns_file:
            os.remove(h5f_path)
83
def test_BasicStatisticsSeries_common_case(h5f=None):
    """Append four vectors via BasicStatisticsSeries and check the stored stats.

    The expected mean/min/max/std columns are the per-vector statistics of
    the appended lists. The std values match the population standard
    deviation, e.g. 0.0624 for [0.15, 0.20, 0.30].

    :param h5f: optional PyTables file already open for writing. When
        omitted, a temporary file is created here and removed at the end.
        Either way the file is closed by this test and reopened read-only.
    """
    import os

    owns_file = h5f is None
    if owns_file:
        # mkstemp really creates the file; reusing the name of a deleted
        # NamedTemporaryFile leaves a window for another process to take it.
        fd, tmp_path = tempfile.mkstemp(suffix='.h5')
        os.close(fd)
        h5f = openFile(tmp_path, "w")
    # Needed to reopen the file below; it must also be set when the caller
    # supplies h5f, otherwise the reopen raises NameError.
    h5f_path = h5f.filename

    stats_series = BasicStatisticsSeries(table_name="b_vector_statistics",
                                         hdf5_file=h5f,
                                         index_names=('epoch', 'minibatch'),
                                         title="Basic statistics for b vector indexed by epoch and minibatch")

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    stats_series.append((1, 1), [0.15, 0.20, 0.30])
    stats_series.append((1, 2), [-0.18, 0.30, 0.58])
    stats_series.append((2, 1), [0.18, -0.38, -0.68])
    stats_series.append((2, 2), [0.15, 0.02, 1.9])

    # Close the writer before reopening the same file read-only.
    h5f.close()

    h5f = openFile(h5f_path, "r")
    try:
        table = h5f.getNode('/', 'b_vector_statistics')

        assert compare_lists(table.cols.epoch[:], [1, 1, 2, 2])
        assert compare_lists(table.cols.minibatch[:], [1, 2, 1, 2])
        assert compare_lists(table.cols.mean[:],
                             [0.21666667, 0.23333333, -0.29333332, 0.69],
                             floats=True)
        assert compare_lists(table.cols.min[:],
                             [0.15000001, -0.18000001, -0.68000001, 0.02],
                             floats=True)
        assert compare_lists(table.cols.max[:],
                             [0.30, 0.58, 0.18, 1.9],
                             floats=True)
        assert compare_lists(table.cols.std[:],
                             [0.06236095, 0.31382939, 0.35640177, 0.85724366],
                             floats=True)
    finally:
        h5f.close()
        if owns_file:
            os.remove(h5f_path)
111
def test_ParamsStatisticsWrapper_commoncase(h5f=None):
    """Record stats for three random vectors via ParamsStatisticsWrapper.

    One table per array name is read back from the '/params' group. Its
    first-row mean and min are compared with numpy's values for the same
    vector.

    :param h5f: optional PyTables file already open for writing. When
        omitted, a temporary file is created here and removed at the end.
        Either way the file is closed by this test and reopened read-only.
    """
    import os

    owns_file = h5f is None
    if owns_file:
        # mkstemp really creates the file; reusing the name of a deleted
        # NamedTemporaryFile leaves a window for another process to take it.
        fd, tmp_path = tempfile.mkstemp(suffix='.h5')
        os.close(fd)
        h5f = openFile(tmp_path, "w")
    # Needed to reopen the file below; it must also be set when the caller
    # supplies h5f, otherwise the reopen raises NameError.
    h5f_path = h5f.filename

    array_names = ('b1', 'b2', 'b3')
    stats = ParamsStatisticsWrapper(new_group_name="params", base_group="/",
                                    arrays_names=array_names, hdf5_file=h5f,
                                    index_names=('epoch', 'minibatch'))

    # numpy.random is already imported at module level.
    arrays = [numpy.random.rand(5) for _ in array_names]
    stats.append((1, 1), arrays)

    # Close the writer before reopening the same file read-only.
    h5f.close()

    h5f = openFile(h5f_path, "r")
    try:
        for name, values in zip(array_names, arrays):
            node = h5f.getNode('/params', name)
            # abs() is required: a one-sided `a - b < tol` passes whenever
            # the stored value is too small, however far off it is.
            assert abs(node.cols.mean[0] - numpy.mean(values)) < 1e-3
            assert abs(node.cols.min[0] - numpy.min(values)) < 1e-3
    finally:
        h5f.close()
        if owns_file:
            os.remove(h5f_path)
139
def test_get_desc():
    """Check get_description_with_n_ints_n_floats and its scoping.

    The description it returns must create a table exposing the four
    requested columns (col1..col4), and the name LocalDescription must not
    be visible at this module's scope.
    """
    import os

    # mkstemp really creates the file; reusing the name of a deleted
    # NamedTemporaryFile leaves a window for another process to take it.
    fd, h5f_path = tempfile.mkstemp(suffix='.h5')
    os.close(fd)
    h5f = openFile(h5f_path, "w")
    try:
        desc = get_description_with_n_ints_n_floats(("col1", "col2"),
                                                    ("col3", "col4"))

        mytable = h5f.createTable('/', 'mytable', desc)

        # Attribute access raises if a column is missing.
        for column_name in ("col1", "col2", "col3", "col4"):
            getattr(mytable.cols, column_name)

        # LocalDescription must stay local to get_description_with_n_ints_n_floats.
        # A bare `except:` here would also swallow the failure signal itself,
        # so only NameError counts as success.
        try:
            LocalDescription
        except NameError:
            pass
        else:
            raise AssertionError("LocalDescription is visible at module scope")
    finally:
        h5f.close()
        os.remove(h5f_path)
162
if __name__ == '__main__':
    # Run every test in this module, in a fixed order, when executed as a
    # script. tempfile is already imported at module top.
    for run in (test_get_desc,
                test_ErrorSeries_common_case,
                test_BasicStatisticsSeries_common_case,
                test_AccumulatorSeriesWrapper_common_case,
                test_ParamsStatisticsWrapper_commoncase):
        run()
170