comparison utils/scalar_series/test_series.py @ 186:d364a130b221

Ajout du code de base pour scalar_series. Modifications à stacked_dae: réglé un problème avec les input_divider (empêchait une optimisation), et ajouté utilisation des séries. Si j'avais pas déjà commité, aussi, j'ai enlevé l'histoire de réutilisation du pretraining: c'était compliqué (error prone) et ça créait des jobs beaucoup trop longues.
author fsavard
date Mon, 01 Mar 2010 11:45:25 -0500
parents
children
comparison
equal deleted inserted replaced
185:b9ea8e2d071a 186:d364a130b221
1 #!/usr/bin/python
2 # coding: utf-8
3
import atexit
import os
import os.path
import shutil
import sys
import tempfile

import numpy

from series import BaseSeries, AccumulatorSeries, SeriesContainer, BasicStatsSeries, SeriesMultiplexer, SeriesList, ParamsArrayStats
12
13
# Scratch root shared by the tests in this module: tempname(), tempdir() and
# the SeriesList tests allocate their files/directories inside it. It is
# removed at interpreter exit so repeated runs do not pile up directories in
# the system temp dir (errors during removal are ignored).
BASEDIR = tempfile.mkdtemp()
atexit.register(shutil.rmtree, BASEDIR, ignore_errors=True)
15
def tempname(basedir=None):
    """Reserve a unique file name inside *basedir* and return it split.

    A NamedTemporaryFile is created only to obtain a unique name. It is then
    closed, which (with the default delete=True) removes it, so the returned
    path does not exist when this function returns.

    :param basedir: directory in which to pick the name; defaults to the
        module-level BASEDIR.
    :returns: ``(directory, filename)`` as produced by ``os.path.split``.
    """
    if basedir is None:
        basedir = BASEDIR
    tmp = tempfile.NamedTemporaryFile(dir=basedir)
    path = tmp.name
    # Close explicitly instead of relying on reference-counting GC. If an
    # unclosed wrapper were collected later, it would delete whatever the
    # caller had written at this path in the meantime.
    tmp.close()
    return os.path.split(path)
20
def tempdir(basedir=None):
    """Create a new, empty directory inside *basedir* and return it split.

    :param basedir: directory in which to create the new one; defaults to the
        module-level BASEDIR.
    :returns: ``(parent_directory, dirname)`` such that
        ``os.path.join(parent_directory, dirname)`` is the freshly created
        directory (callers pass these as ``parent_directory`` / ``name``).
    """
    if basedir is None:
        basedir = BASEDIR
    # Split the new directory's own path, treating its last component as a
    # "filename". Taking os.path.dirname() of it first would just yield
    # basedir, so every call would hand back the same directory.
    return os.path.split(tempfile.mkdtemp(dir=basedir))
25
def tempseries(type='f', flush_every=1):
    """Return a BaseSeries stored under a fresh name obtained from tempname().

    ``type`` and ``flush_every`` are forwarded unchanged to BaseSeries.
    """
    series_dir, series_name = tempname()
    return BaseSeries(name=series_name, directory=series_dir,
                      type=type, flush_every=flush_every)
32
def test_Series_storeload():
    """Values written by one BaseSeries can be read back by another."""
    # tempseries() defaults to flush_every=1, so each value is flushed.
    writer = tempseries()
    writer.append(12.0)
    writer.append_list([13.0, 14.0, 15.0])

    # Re-open the same storage with a large flush interval. The values
    # appended before load_from_file() (never flushed) must not survive the
    # load; dropping them or rebuilding the array from scratch both qualify.
    reader = BaseSeries(name=writer.name, directory=writer.directory,
                        flush_every=15)
    reader.append(10.0)
    reader.append_list([30.0, 40.0])
    reader.load_from_file()

    assert reader.tolist() == [12.0, 13.0, 14.0, 15.0]
47
48
def test_AccumulatorSeries_mean():
    """AccumulatorSeries(mean=True) stores the mean of every reduce_every values."""
    target_dir, target_name = tempname()

    acc = AccumulatorSeries(reduce_every=15, mean=True,
                            name=target_name, directory=target_dir)
    for value in range(50):
        acc.append(value)

    # Means of 0..14, 15..29 and 30..44 are 7, 22 and 37; the trailing
    # 5 values (45..49) are expected not to produce an entry.
    assert acc.tolist() == [7.0, 22.0, 37.0]
58
def test_BasicStatsSeries_commoncase():
    """Check BasicStatsSeries statistics in memory and again after a reload.

    Each appended array contributes one entry to means/mins/maxes/stds.
    The expected stds (sqrt(52) ~= 7.2111 for arange(25) and
    sqrt(133.25) ~= 11.5434 for arange(40)) are population standard
    deviations.
    """
    a1 = numpy.arange(25).reshape((5, 5))
    a2 = numpy.arange(40).reshape((8, 5))

    parent_dir, name = tempdir()

    def _check_stats(stats):
        assert stats.means.tolist() == [12.0, 19.5]
        assert stats.mins.tolist() == [0.0, 0.0]
        assert stats.maxes.tolist() == [24.0, 39.0]
        # abs() matters: a signed difference would accept any std below the
        # expected value (including 0.0).
        assert abs(stats.stds.tolist()[0] - 7.211102) < 1e-3
        assert abs(stats.stds.tolist()[1] - 11.54339) < 1e-3

    bss = BasicStatsSeries(parent_directory=parent_dir, name=name)
    bss.append(a1)
    bss.append(a2)
    _check_stats(bss)

    # try to reload
    bss2 = BasicStatsSeries(parent_directory=parent_dir, name=name)
    bss2.load_from_directory()
    _check_stats(bss2)
86
def test_BasicStatsSeries_reload():
    """A fresh BasicStatsSeries reloads statistics written by another one.

    The expected stds are population standard deviations of arange(25)
    (sqrt(52)) and arange(40) (sqrt(133.25)).
    """
    a1 = numpy.arange(25).reshape((5, 5))
    a2 = numpy.arange(40).reshape((8, 5))

    parent_dir, name = tempdir()

    bss = BasicStatsSeries(parent_directory=parent_dir, name=name)
    bss.append(a1)
    bss.append(a2)

    # try to reload
    bss2 = BasicStatsSeries(parent_directory=parent_dir, name=name)
    bss2.load_from_directory()

    assert bss2.means.tolist() == [12.0, 19.5]
    assert bss2.mins.tolist() == [0.0, 0.0]
    assert bss2.maxes.tolist() == [24.0, 39.0]
    # abs() matters: a signed difference would accept any std below the
    # expected value (including 0.0).
    assert abs(bss2.stds.tolist()[0] - 7.211102) < 1e-3
    assert abs(bss2.stds.tolist()[1] - 11.54339) < 1e-3
108
109
def test_BasicStatsSeries_withaccumulator():
    """BasicStatsSeries built on AccumulatorSeries(reduce_every=2, mean=False)."""
    inputs = [
        numpy.arange(25).reshape((5, 5)),
        numpy.arange(40).reshape((8, 5)),
        numpy.arange(20).reshape((4, 5)),
        numpy.arange(48).reshape((6, 8)),
    ]

    parent_dir, series_name = tempdir()

    accumulator_ctor = AccumulatorSeries.series_constructor(reduce_every=2,
                                                            mean=False)
    stats = BasicStatsSeries(parent_directory=parent_dir, name=series_name,
                             series_constructor=accumulator_ctor)

    for arr in inputs:
        stats.append(arr)

    # Per-array means are 12, 19.5, 9.5 and 23.5; the expected values are
    # the sums of consecutive pairs: 12 + 19.5 and 9.5 + 23.5.
    assert stats.means.tolist() == [31.5, 33.0]
128
def test_SeriesList_withbasicstats():
    """Store 5 parallel BasicStatsSeries through a SeriesList, then reload them.

    At step i, sub-series j receives numpy.arange(i*j, i*j + 10), whose mean
    is i*j + 4.5. Sub-series 0 therefore always sees 4.5, and sub-series 4
    sees 4*i + 4.5.
    """
    directory = tempfile.mkdtemp(dir=BASEDIR)

    bscstr = BasicStatsSeries.series_constructor()

    slist = SeriesList(num_elements=5, name="foo", directory=directory,
                       series_constructor=bscstr)

    for i in range(10):  # 10 elements in each list
        # one array per sub-series (5 = num_elements)
        slist.append([numpy.arange(i * j, i * j + 10) for j in range(5)])

    slist2 = SeriesList(num_elements=5, name="foo", directory=directory,
                        series_constructor=bscstr)
    slist2.load_from_files()

    l1 = slist2._subseries[0].means.tolist()
    l2 = slist2._subseries[4].means.tolist()

    # Single-argument print() produces the same output on Python 2 and 3.
    print(l1)
    print(l2)

    assert l1 == [4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5]
    assert l2 == [4.5, 8.5, 12.5, 16.5, 20.5, 24.5, 28.5, 32.5, 36.5, 40.5]
155
# Same test as test_SeriesList_withbasicstats above, but using the ParamsArrayStats shortcut.
def test_ParamsArrayStats_reload():
    """Same scenario as test_SeriesList_withbasicstats, via ParamsArrayStats.

    At step i, element j receives numpy.arange(i*j, i*j + 10), whose mean
    is i*j + 4.5. The expected reloaded means are therefore identical to
    those of the SeriesList version.
    """
    directory = tempfile.mkdtemp(dir=BASEDIR)

    slist = ParamsArrayStats(5, name="foo", directory=directory)

    for i in range(10):  # 10 elements in each list
        # one array per element (5 = number of lists appended to)
        slist.append([numpy.arange(i * j, i * j + 10) for j in range(5)])

    slist2 = ParamsArrayStats(5, name="foo", directory=directory)
    slist2.load_from_files()

    l1 = slist2._subseries[0].means.tolist()
    l2 = slist2._subseries[4].means.tolist()

    # Single-argument print() produces the same output on Python 2 and 3.
    print(l1)
    print(l2)

    assert l1 == [4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5]
    assert l2 == [4.5, 8.5, 12.5, 16.5, 20.5, 24.5, 28.5, 32.5, 36.5, 40.5]
181
182
def manual_BasicStatsSeries_graph():
    """Feed 50 small arrays to a BasicStatsSeries and call its graph().

    Meant to be run by hand (the name has no ``test`` prefix, presumably so
    test runners skip it), e.g. via the commented-out ``__main__`` snippet
    in this file, followed by ``pylab.show()``.
    """
    parent_dir, name = tempdir()

    bss = BasicStatsSeries(parent_directory=parent_dir, name=name)

    for i in range(50):
        # Start the range at 1 rather than 0: for i == 0, arange(0, 5)
        # contains 0, and 1.0/0 yields inf (with a divide-by-zero warning),
        # which would make the first plotted point's statistics non-finite.
        bss.append(1.0 / numpy.arange(i * 5 + 1, i * 5 + 6))

    bss.graph()
192
193 #if __name__ == '__main__':
194 # import pylab
195 # manual_BasicStatsSeries_graph()
196 # pylab.show()
197