Mercurial > ift6266
comparison utils/seriestables/series.py @ 220:e172ef73cdc5
Ajouté un paquet de type/value checks à SeriesTables, et finalisé les docstrings. Ajouté 3-4 tests. Légers refactorings ici et là sans conséquences externes.
author | fsavard |
---|---|
date | Thu, 11 Mar 2010 10:48:54 -0500 |
parents | 4c137f16b013 |
children | 02d9c1279dd8 |
comparison
equal
deleted
inserted
replaced
219:cde71d24f235 | 220:e172ef73cdc5 |
---|---|
1 from tables import * | 1 import tables |
2 | |
2 import numpy | 3 import numpy |
3 import time | 4 import time |
5 | |
6 ############################################################################## | |
7 # Utility functions to create IsDescription objects (pytables data types) | |
4 | 8 |
5 ''' | 9 ''' |
6 The way these "IsDescription constructors" work is simple: write the | 10 The way these "IsDescription constructors" work is simple: write the |
7 code as if it were in a file, then exec()ute it, leaving us with | 11 code as if it were in a file, then exec()ute it, leaving us with |
8 a local-scoped LocalDescription which may be used to call createTable. | 12 a local-scoped LocalDescription which may be used to call createTable. |
14 | 18 |
15 def _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock, pos=0): | 19 def _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock, pos=0): |
16 toexec = "" | 20 toexec = "" |
17 | 21 |
18 if store_timestamp: | 22 if store_timestamp: |
19 toexec += "\ttimestamp = Time32Col(pos="+str(pos)+")\n" | 23 toexec += "\ttimestamp = tables.Time32Col(pos="+str(pos)+")\n" |
20 pos += 1 | 24 pos += 1 |
21 | 25 |
22 if store_cpuclock: | 26 if store_cpuclock: |
23 toexec += "\tcpuclock = Float64Col(pos="+str(pos)+")\n" | 27 toexec += "\tcpuclock = tables.Float64Col(pos="+str(pos)+")\n" |
24 pos += 1 | 28 pos += 1 |
25 | 29 |
26 return toexec, pos | 30 return toexec, pos |
27 | 31 |
28 def _get_description_n_ints(int_names, int_width=64, pos=0): | 32 def _get_description_n_ints(int_names, int_width=64, pos=0): |
30 Begins construction of a class inheriting from IsDescription | 34 Begins construction of a class inheriting from IsDescription |
31 to construct an HDF5 table with index columns named with int_names. | 35 to construct an HDF5 table with index columns named with int_names. |
32 | 36 |
33 See Series().__init__ to see how those are used. | 37 See Series().__init__ to see how those are used. |
34 """ | 38 """ |
35 int_constructor = "Int64Col" | 39 int_constructor = "tables.Int64Col" |
36 if int_width == 32: | 40 if int_width == 32: |
37 int_constructor = "Int32Col" | 41 int_constructor = "tables.Int32Col" |
38 | 42 |
39 toexec = "" | 43 toexec = "" |
40 | 44 |
41 for n in int_names: | 45 for n in int_names: |
42 toexec += "\t" + n + " = " + int_constructor + "(pos=" + str(pos) + ")\n" | 46 toexec += "\t" + n + " = " + int_constructor + "(pos=" + str(pos) + ")\n" |
72 Returns | 76 Returns |
73 ------- | 77 ------- |
74 A class object, to pass to createTable() | 78 A class object, to pass to createTable() |
75 """ | 79 """ |
76 | 80 |
77 toexec = "class LocalDescription(IsDescription):\n" | 81 toexec = "class LocalDescription(tables.IsDescription):\n" |
78 | 82 |
79 toexec_, pos = _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock) | 83 toexec_, pos = _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock) |
80 toexec += toexec_ | 84 toexec += toexec_ |
81 | 85 |
82 toexec_, pos = _get_description_n_ints(int_names, int_width=int_width, pos=pos) | 86 toexec_, pos = _get_description_n_ints(int_names, int_width=int_width, pos=pos) |
83 toexec += toexec_ | 87 toexec += toexec_ |
84 | 88 |
85 float_constructor = "Float32Col" | 89 float_constructor = "tables.Float32Col" |
86 if float_width == 64: | 90 if float_width == 64: |
87 float_constructor = "Float64Col" | 91 float_constructor = "tables.Float64Col" |
88 | 92 |
89 for n in float_names: | 93 for n in float_names: |
90 toexec += "\t" + n + " = " + float_constructor + "(pos=" + str(pos) + ")\n" | 94 toexec += "\t" + n + " = " + float_constructor + "(pos=" + str(pos) + ")\n" |
91 pos += 1 | 95 pos += 1 |
92 | 96 |
93 exec(toexec) | 97 exec(toexec) |
94 | 98 |
95 return LocalDescription | 99 return LocalDescription |
96 | 100 |
101 ############################################################################## | |
102 # Series classes | |
103 | |
104 # Shortcut to allow passing a single int as index, instead of a tuple | |
105 def _index_to_tuple(index): | |
106 if type(index) == tuple: | |
107 return index | |
108 | |
109 if type(index) == list: | |
110 index = tuple(index) | |
111 return index | |
112 | |
113 try: | |
114 if index % 1 > 0.001 and index % 1 < 0.999: | |
115 raise | |
116 idx = long(index) | |
117 return (idx,) | |
118 except: | |
119 raise TypeError("index must be a tuple of integers, or at least a single integer") | |
120 | |
97 class Series(): | 121 class Series(): |
122 """ | |
123 Base Series class, with minimal arguments and type checks. | |
124 | |
125 Yet cannot be used by itself (its append() method raises an error) | |
126 """ | |
127 | |
98 def __init__(self, table_name, hdf5_file, index_names=('epoch',), | 128 def __init__(self, table_name, hdf5_file, index_names=('epoch',), |
99 title="", hdf5_group='/', | 129 title="", hdf5_group='/', |
100 store_timestamp=True, store_cpuclock=True): | 130 store_timestamp=True, store_cpuclock=True): |
101 """Basic arguments each Series must get. | 131 """Basic arguments each Series must get. |
102 | 132 |
103 Parameters | 133 Parameters |
104 ---------- | 134 ---------- |
105 table_name : str | 135 table_name : str |
106 Name of the table to create under group "hd5_group" (other parameter). No spaces, ie. follow variable naming restrictions. | 136 Name of the table to create under group "hd5_group" (other |
137 parameter). No spaces, ie. follow variable naming restrictions. | |
107 hdf5_file : open HDF5 file | 138 hdf5_file : open HDF5 file |
108 File opened with openFile() in PyTables (ie. return value of openFile). | 139 File opened with openFile() in PyTables (ie. return value of |
140 openFile). | |
109 index_names : tuple of str | 141 index_names : tuple of str |
110 Columns to use as index for elements in the series, other example would be ('epoch', 'minibatch'). This would then allow you to call append(index, element) with index made of two ints, one for epoch index, one for minibatch index in epoch. | 142 Columns to use as index for elements in the series, other |
143 example would be ('epoch', 'minibatch'). This would then allow | |
144 you to call append(index, element) with index made of two ints, | |
145 one for epoch index, one for minibatch index in epoch. | |
111 title : str | 146 title : str |
112 Title to attach to this table as metadata. Can contain spaces | 147 Title to attach to this table as metadata. Can contain spaces |
and be longer than the table_name. | 148 and be longer than the table_name. |
113 hdf5_group : str | 149 hdf5_group : str |
114 Path of the group (kind of a file) in the HDF5 file under which to create the table. | 150 Path of the group (kind of a file) in the HDF5 file under which |
151 to create the table. | |
115 store_timestamp : bool | 152 store_timestamp : bool |
116 Whether to create a column for timestamps and store them with each record. | 153 Whether to create a column for timestamps and store them with |
154 each record. | |
117 store_cpuclock : bool | 155 store_cpuclock : bool |
118 Whether to create a column for cpu clock and store it with each record. | 156 Whether to create a column for cpu clock and store it with |
119 """ | 157 each record. |
158 """ | |
159 | |
160 ######################################### | |
161 # checks | |
162 | |
163 if type(table_name) != str: | |
164 raise TypeError("table_name must be a string") | |
165 if table_name == "": | |
166 raise ValueError("table_name must not be empty") | |
167 | |
168 if not isinstance(hdf5_file, tables.file.File): | |
169 raise TypeError("hdf5_file must be an open HDF5 file (use tables.openFile)") | |
170 #if not ('w' in hdf5_file.mode or 'a' in hdf5_file.mode): | |
171 # raise ValueError("hdf5_file must be opened in write or append mode") | |
172 | |
173 if type(index_names) != tuple: | |
174 raise TypeError("index_names must be a tuple of strings." + \ | |
175 "If you have only one element in the tuple, don't forget " +\ | |
176 "to add a comma, e.g. ('epoch',).") | |
177 for name in index_names: | |
178 if type(name) != str: | |
179 raise TypeError("index_names must only contain strings, but also"+\ | |
180 "contains a "+str(type(name))+".") | |
181 | |
182 if type(title) != str: | |
183 raise TypeError("title must be a string, even if empty") | |
184 | |
185 if type(hdf5_group) != str: | |
186 raise TypeError("hdf5_group must be a string") | |
187 | |
188 if type(store_timestamp) != bool: | |
189 raise TypeError("store_timestamp must be a bool") | |
190 | |
191 if type(store_cpuclock) != bool: | |
192 raise TypeError("store_timestamp must be a bool") | |
193 | |
194 ######################################### | |
195 | |
120 self.table_name = table_name | 196 self.table_name = table_name |
121 self.hdf5_file = hdf5_file | 197 self.hdf5_file = hdf5_file |
122 self.index_names = index_names | 198 self.index_names = index_names |
123 self.title = title | 199 self.title = title |
200 self.hdf5_group = hdf5_group | |
124 | 201 |
125 self.store_timestamp = store_timestamp | 202 self.store_timestamp = store_timestamp |
126 self.store_cpuclock = store_cpuclock | 203 self.store_cpuclock = store_cpuclock |
127 | 204 |
128 def append(self, index, element): | 205 def append(self, index, element): |
130 | 207 |
131 def _timestamp_cpuclock(self, newrow): | 208 def _timestamp_cpuclock(self, newrow): |
132 newrow["timestamp"] = time.time() | 209 newrow["timestamp"] = time.time() |
133 newrow["cpuclock"] = time.clock() | 210 newrow["cpuclock"] = time.clock() |
134 | 211 |
135 # To put in a series dictionary instead of a real series, to do nothing | |
136 # when we don't want a given series to be saved. | |
137 class DummySeries(): | 212 class DummySeries(): |
213 """ | |
214 To put in a series dictionary instead of a real series, to do nothing | |
215 when we don't want a given series to be saved. | |
216 | |
217 E.g. if we'd normally have a "training_error" series in a dictionary | |
218 of series, the training loop would have something like this somewhere: | |
219 | |
220 series["training_error"].append((15,), 20.0) | |
221 | |
222 but if we don't want to save the training errors this time, we simply | |
223 do | |
224 | |
225 series["training_error"] = DummySeries() | |
226 """ | |
138 def append(self, index, element): | 227 def append(self, index, element): |
139 pass | 228 pass |
140 | 229 |
141 class ErrorSeries(Series): | 230 class ErrorSeries(Series): |
231 """ | |
232 Most basic Series: saves a single float (called an Error as this is | |
233 the most common use case I foresee) along with an index (epoch, for | |
234 example) and timestamp/cpu.clock for each of these floats. | |
235 """ | |
236 | |
142 def __init__(self, error_name, table_name, | 237 def __init__(self, error_name, table_name, |
143 hdf5_file, index_names=('epoch',), | 238 hdf5_file, index_names=('epoch',), |
144 title="", hdf5_group='/', | 239 title="", hdf5_group='/', |
145 store_timestamp=True, store_cpuclock=True): | 240 store_timestamp=True, store_cpuclock=True): |
146 Series.__init__(self, table_name, hdf5_file, index_names, title, store_timestamp, store_cpuclock) | 241 """ |
242 For most parameters, see Series.__init__ | |
243 | |
244 Parameters | |
245 ---------- | |
246 error_name : str | |
247 In the HDF5 table, column name for the error float itself. | |
248 """ | |
249 | |
250 # most type/value checks are performed in Series.__init__ | |
251 Series.__init__(self, table_name, hdf5_file, index_names, title, | |
252 store_timestamp=store_timestamp, | |
253 store_cpuclock=store_cpuclock) | |
254 | |
255 if type(error_name) != str: | |
256 raise TypeError("error_name must be a string") | |
257 if error_name == "": | |
258 raise ValueError("error_name must not be empty") | |
147 | 259 |
148 self.error_name = error_name | 260 self.error_name = error_name |
149 | 261 |
150 table_description = self._get_table_description() | 262 self._create_table() |
151 | 263 |
152 self._table = hdf5_file.createTable(hdf5_group, self.table_name, table_description, title=title) | 264 def _create_table(self): |
153 | 265 table_description = _get_description_with_n_ints_n_floats( \ |
154 def _get_table_description(self): | 266 self.index_names, (self.error_name,)) |
155 return _get_description_with_n_ints_n_floats(self.index_names, (self.error_name,)) | 267 |
268 self._table = self.hdf5_file.createTable(self.hdf5_group, | |
269 self.table_name, | |
270 table_description, | |
271 title=self.title) | |
272 | |
156 | 273 |
157 def append(self, index, error): | 274 def append(self, index, error): |
158 """ | 275 """ |
159 Parameters | 276 Parameters |
160 ---------- | 277 ---------- |
161 index : tuple of int | 278 index : tuple of int |
162 Following index_names passed to __init__, e.g. (12, 15) if index_names were ('epoch', 'minibatch_size') | 279 Following index_names passed to __init__, e.g. (12, 15) if |
280 index_names were ('epoch', 'minibatch_size'). | |
281 A single int (not tuple) is acceptable if index_names has a single | |
282 element. | |
283 An array will be casted to a tuple, as a convenience. | |
284 | |
163 error : float | 285 error : float |
164 Next error in the series. | 286 Next error in the series. |
165 """ | 287 """ |
288 index = _index_to_tuple(index) | |
289 | |
166 if len(index) != len(self.index_names): | 290 if len(index) != len(self.index_names): |
167 raise ValueError("index provided does not have the right length (expected " \ | 291 raise ValueError("index provided does not have the right length (expected " \ |
168 + str(len(self.index_names)) + " got " + str(len(index))) | 292 + str(len(self.index_names)) + " got " + str(len(index))) |
293 | |
294 # other checks are implicit when calling newrow[..] =, | |
295 # which should throw an error if not of the right type | |
169 | 296 |
170 newrow = self._table.row | 297 newrow = self._table.row |
171 | 298 |
172 # Columns for index in table are based on index_names | 299 # Columns for index in table are based on index_names |
173 for col_name, value in zip(self.index_names, index): | 300 for col_name, value in zip(self.index_names, index): |
174 newrow[col_name] = value | 301 newrow[col_name] = value |
175 newrow[self.error_name] = error | 302 newrow[self.error_name] = error |
176 | 303 |
304 # adds timestamp and cpuclock to newrow if necessary | |
177 self._timestamp_cpuclock(newrow) | 305 self._timestamp_cpuclock(newrow) |
178 | 306 |
179 newrow.append() | 307 newrow.append() |
180 | 308 |
181 self.hdf5_file.flush() | 309 self.hdf5_file.flush() |
182 | 310 |
183 # Does not inherit from Series because it does not itself need to | 311 # Does not inherit from Series because it does not itself need to |
184 # access the hdf5_file and does not need a series_name (provided | 312 # access the hdf5_file and does not need a series_name (provided |
185 # by the base_series.) | 313 # by the base_series.) |
186 class AccumulatorSeriesWrapper(): | 314 class AccumulatorSeriesWrapper(): |
187 """ | 315 ''' |
188 | 316 Wraps a Series by accumulating objects passed its Accumulator.append() |
189 """ | 317 method and "reducing" (e.g. calling numpy.mean(list)) once in a while, |
318 every "reduce_every" calls in fact. | |
319 ''' | |
320 | |
190 def __init__(self, base_series, reduce_every, reduce_function=numpy.mean): | 321 def __init__(self, base_series, reduce_every, reduce_function=numpy.mean): |
191 """ | 322 """ |
192 Parameters | 323 Parameters |
193 ---------- | 324 ---------- |
194 base_series : Series | 325 base_series : Series |
195 This object must have an append(index, value) function. | 326 This object must have an append(index, value) function. |
327 | |
196 reduce_every : int | 328 reduce_every : int |
197 Apply the reduction function (e.g. mean()) every time we get this number of elements. E.g. if this is 100, then every 100 numbers passed to append(), we'll take the mean and call append(this_mean) on the BaseSeries. | 329 Apply the reduction function (e.g. mean()) every time we get this |
330 number of elements. E.g. if this is 100, then every 100 numbers | |
331 passed to append(), we'll take the mean and call append(this_mean) | |
332 on the BaseSeries. | |
333 | |
198 reduce_function : function | 334 reduce_function : function |
199 Must take as input an array of "elements", as passed to (this accumulator's) append(). Basic case would be to take an array of floats and sum them into one float, for example. | 335 Must take as input an array of "elements", as passed to (this |
336 accumulator's) append(). Basic case would be to take an array of | |
337 floats and sum them into one float, for example. | |
200 """ | 338 """ |
201 self.base_series = base_series | 339 self.base_series = base_series |
202 self.reduce_function = reduce_function | 340 self.reduce_function = reduce_function |
203 self.reduce_every = reduce_every | 341 self.reduce_every = reduce_every |
204 | 342 |
211 ---------- | 349 ---------- |
212 index : tuple of int | 350 index : tuple of int |
213 The index used is the one of the last element reduced. E.g. if | 351 The index used is the one of the last element reduced. E.g. if |
214 you accumulate over the first 1000 minibatches, the index | 352 you accumulate over the first 1000 minibatches, the index |
215 passed to the base_series.append() function will be 1000. | 353 passed to the base_series.append() function will be 1000. |
354 A single int (not tuple) is acceptable if index_names has a single | |
355 element. | |
356 An array will be casted to a tuple, as a convenience. | |
357 | |
216 element : float | 358 element : float |
217 Element that will be accumulated. | 359 Element that will be accumulated. |
218 """ | 360 """ |
219 self._buffer.append(element) | 361 self._buffer.append(element) |
220 | 362 |
221 if len(self._buffer) == self.reduce_every: | 363 if len(self._buffer) == self.reduce_every: |
222 reduced = self.reduce_function(self._buffer) | 364 reduced = self.reduce_function(self._buffer) |
223 self.base_series.append(index, reduced) | 365 self.base_series.append(index, reduced) |
224 self._buffer = [] | 366 self._buffer = [] |
225 | 367 |
226 # This should never happen, except if lists | 368 # The >= case should never happen, except if lists |
227 # were appended, which should be a red flag. | 369 # were appended, which should be a red flag. |
228 assert len(self._buffer) < self.reduce_every | 370 assert len(self._buffer) < self.reduce_every |
229 | 371 |
230 # Outside of class to fix an issue with exec in Python 2.6. | 372 # Outside of class to fix an issue with exec in Python 2.6. |
231 # My sorries to the God of pretty code. | 373 # My sorries to the god of pretty code. |
232 def _BasicStatisticsSeries_construct_table_toexec(index_names, store_timestamp, store_cpuclock): | 374 def _BasicStatisticsSeries_construct_table_toexec(index_names, store_timestamp, store_cpuclock): |
233 toexec = "class LocalDescription(IsDescription):\n" | 375 toexec = "class LocalDescription(tables.IsDescription):\n" |
234 | 376 |
235 toexec_, pos = _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock) | 377 toexec_, pos = _get_description_timestamp_cpuclock_columns(store_timestamp, store_cpuclock) |
236 toexec += toexec_ | 378 toexec += toexec_ |
237 | 379 |
238 toexec_, pos = _get_description_n_ints(index_names, pos=pos) | 380 toexec_, pos = _get_description_n_ints(index_names, pos=pos) |
239 toexec += toexec_ | 381 toexec += toexec_ |
240 | 382 |
241 toexec += "\tmean = Float32Col(pos=" + str(pos) + ")\n" | 383 toexec += "\tmean = tables.Float32Col(pos=" + str(pos) + ")\n" |
242 toexec += "\tmin = Float32Col(pos=" + str(pos+1) + ")\n" | 384 toexec += "\tmin = tables.Float32Col(pos=" + str(pos+1) + ")\n" |
243 toexec += "\tmax = Float32Col(pos=" + str(pos+2) + ")\n" | 385 toexec += "\tmax = tables.Float32Col(pos=" + str(pos+2) + ")\n" |
244 toexec += "\tstd = Float32Col(pos=" + str(pos+3) + ")\n" | 386 toexec += "\tstd = tables.Float32Col(pos=" + str(pos+3) + ")\n" |
245 | 387 |
246 # This creates "LocalDescription", which we may then use | 388 # This creates "LocalDescription", which we may then use |
247 exec(toexec) | 389 exec(toexec) |
248 | 390 |
249 return LocalDescription | 391 return LocalDescription |
250 | 392 |
251 basic_stats_functions = {'mean': lambda(x): numpy.mean(x), | 393 # Defaults functions for BasicStatsSeries. These can be replaced. |
394 _basic_stats_functions = {'mean': lambda(x): numpy.mean(x), | |
252 'min': lambda(x): numpy.min(x), | 395 'min': lambda(x): numpy.min(x), |
253 'max': lambda(x): numpy.max(x), | 396 'max': lambda(x): numpy.max(x), |
254 'std': lambda(x): numpy.std(x)} | 397 'std': lambda(x): numpy.std(x)} |
255 | 398 |
256 class BasicStatisticsSeries(Series): | 399 class BasicStatisticsSeries(Series): |
257 """ | 400 |
258 Parameters | |
259 ---------- | |
260 series_name : str | |
261 Not optional here. Will be prepended with "Basic statistics for " | |
262 stats_functions : dict, optional | |
263 Dictionary with a function for each key "mean", "min", "max", "std". The function must take whatever is passed to append(...) and return a single number (float). | |
264 """ | |
265 def __init__(self, table_name, hdf5_file, | 401 def __init__(self, table_name, hdf5_file, |
266 stats_functions=basic_stats_functions, | 402 stats_functions=_basic_stats_functions, |
267 index_names=('epoch',), title="", hdf5_group='/', | 403 index_names=('epoch',), title="", hdf5_group='/', |
268 store_timestamp=True, store_cpuclock=True): | 404 store_timestamp=True, store_cpuclock=True): |
269 Series.__init__(self, table_name, hdf5_file, index_names, title, store_timestamp, store_cpuclock) | 405 """ |
406 For most parameters, see Series.__init__ | |
407 | |
408 Parameters | |
409 ---------- | |
410 series_name : str | |
411 Not optional here. Will be prepended with "Basic statistics for " | |
412 | |
413 stats_functions : dict, optional | |
414 Dictionary with a function for each key "mean", "min", "max", | |
415 "std". The function must take whatever is passed to append(...) | |
416 and return a single number (float). | |
417 """ | |
418 | |
419 # Most type/value checks performed in Series.__init__ | |
420 Series.__init__(self, table_name, hdf5_file, index_names, title, | |
421 store_timestamp=store_timestamp, | |
422 store_cpuclock=store_cpuclock) | |
423 | |
424 if type(hdf5_group) != str: | |
425 raise TypeError("hdf5_group must be a string") | |
426 | |
427 if type(stats_functions) != dict: | |
428 # just a basic check. We'll suppose caller knows what he's doing. | |
429 raise TypeError("stats_functions must be a dict") | |
270 | 430 |
271 self.hdf5_group = hdf5_group | 431 self.hdf5_group = hdf5_group |
272 | 432 |
273 self.stats_functions = stats_functions | 433 self.stats_functions = stats_functions |
274 | 434 |
275 self._construct_table() | 435 self._create_table() |
276 | 436 |
277 def _construct_table(self): | 437 def _create_table(self): |
278 table_description = _BasicStatisticsSeries_construct_table_toexec(self.index_names, self.store_timestamp, self.store_cpuclock) | 438 table_description = \ |
279 | 439 _BasicStatisticsSeries_construct_table_toexec( \ |
280 self._table = self.hdf5_file.createTable(self.hdf5_group, self.table_name, table_description) | 440 self.index_names, |
441 self.store_timestamp, self.store_cpuclock) | |
442 | |
443 self._table = self.hdf5_file.createTable(self.hdf5_group, | |
444 self.table_name, table_description) | |
281 | 445 |
282 def append(self, index, array): | 446 def append(self, index, array): |
283 """ | 447 """ |
284 Parameters | 448 Parameters |
285 ---------- | 449 ---------- |
286 index : tuple of int | 450 index : tuple of int |
287 Following index_names passed to __init__, e.g. (12, 15) if index_names were ('epoch', 'minibatch_size') | 451 Following index_names passed to __init__, e.g. (12, 15) |
452 if index_names were ('epoch', 'minibatch_size') | |
453 A single int (not tuple) is acceptable if index_names has a single | |
454 element. | |
455 An array will be casted to a tuple, as a convenience. | |
456 | |
288 array | 457 array |
289 Is of whatever type the stats_functions passed to __init__ can take. Default is anything numpy.mean(), min(), max(), std() can take. | 458 Is of whatever type the stats_functions passed to |
290 """ | 459 __init__ can take. Default is anything numpy.mean(), |
460 min(), max(), std() can take. | |
461 """ | |
462 index = _index_to_tuple(index) | |
463 | |
291 if len(index) != len(self.index_names): | 464 if len(index) != len(self.index_names): |
292 raise ValueError("index provided does not have the right length (expected " \ | 465 raise ValueError("index provided does not have the right length (expected " \ |
293 + str(len(self.index_names)) + " got " + str(len(index))) | 466 + str(len(self.index_names)) + " got " + str(len(index))) |
294 | 467 |
295 newrow = self._table.row | 468 newrow = self._table.row |
308 | 481 |
309 self.hdf5_file.flush() | 482 self.hdf5_file.flush() |
310 | 483 |
311 class SeriesArrayWrapper(): | 484 class SeriesArrayWrapper(): |
312 """ | 485 """ |
313 Simply redistributes any number of elements to sub-series to respective append()s. | 486 Simply redistributes any number of elements to sub-series to respective |
314 | 487 append()s. |
315 To use if you have many elements to append in similar series, e.g. if you have an array containing [train_error, valid_error, test_error], and 3 corresponding series, this allows you to simply pass this array of 3 values to append() instead of passing each element to each individual series in turn. | 488 |
489 To use if you have many elements to append in similar series, e.g. if you | |
490 have an array containing [train_error, valid_error, test_error], and 3 | |
491 corresponding series, this allows you to simply pass this array of 3 | |
492 values to append() instead of passing each element to each individual | |
493 series in turn. | |
316 """ | 494 """ |
317 | 495 |
318 def __init__(self, base_series_list): | 496 def __init__(self, base_series_list): |
497 """ | |
498 Parameters | |
499 ---------- | |
500 base_series_list : array or tuple of Series | |
501 You must have previously created and configured each of those | |
502 series, then put them in an array. This array must follow the | |
503 same order as the array passed as ``elements`` parameter of | |
504 append(). | |
505 """ | |
319 self.base_series_list = base_series_list | 506 self.base_series_list = base_series_list |
320 | 507 |
321 def append(self, index, elements): | 508 def append(self, index, elements): |
509 """ | |
510 Parameters | |
511 ---------- | |
512 index : tuple of int | |
513 See for example ErrorSeries.append() | |
514 | |
515 elements : array or tuple | |
516 Array or tuple of elements that will be passed down to | |
517 the base_series passed to __init__, in the same order. | |
518 """ | |
322 if len(elements) != len(self.base_series_list): | 519 if len(elements) != len(self.base_series_list): |
323 raise ValueError("not enough or too much elements provided (expected " \ | 520 raise ValueError("not enough or too much elements provided (expected " \ |
324 + str(len(self.base_series_list)) + " got " + str(len(elements))) | 521 + str(len(self.base_series_list)) + " got " + str(len(elements))) |
325 | 522 |
326 for series, el in zip(self.base_series_list, elements): | 523 for series, el in zip(self.base_series_list, elements): |
327 series.append(index, el) | 524 series.append(index, el) |
328 | 525 |
329 class SharedParamsStatisticsWrapper(SeriesArrayWrapper): | 526 class SharedParamsStatisticsWrapper(SeriesArrayWrapper): |
330 '''Save mean, min/max, std of shared parameters placed in an array. | 528 Save mean, min/max, std of shared parameters placed in an array. |
332 This is specifically for cases where we have _shared_ parameters, | 529 |
332 This is specifically for cases where we have _shared_ parameters, | 529 |
333 as we take the .value of each array''' | 530 Here "shared" means "theano.shared", which means elements of the |
334 | 531 array will have a .value to use for numpy.mean(), etc. |
335 def __init__(self, arrays_names, new_group_name, hdf5_file, base_group='/', index_names=('epoch',), title=""): | 532 |
336 """ | 533 This inherits from SeriesArrayWrapper, which provides the append() |
337 Parameters | 534 method. |
338 ---------- | 535 ''' |
339 array_names : array of str | 536 |
340 Name of each array, in order of the array passed to append(). E.g. ('layer1_b', 'layer1_W', 'layer2_b', 'layer2_W') | 537 def __init__(self, arrays_names, new_group_name, hdf5_file, |
538 base_group='/', index_names=('epoch',), title=""): | |
539 """ | |
540 For other parameters, see Series.__init__ | |
541 | |
542 Parameters | |
543 ---------- | |
544 array_names : array or tuple of str | |
545 Name of each array, in order of the array passed to append(). E.g. | |
546 ('layer1_b', 'layer1_W', 'layer2_b', 'layer2_W') | |
547 | |
341 new_group_name : str | 548 new_group_name : str |
342 Name of a new HDF5 group which will be created under base_group to store the new series. | 549 Name of a new HDF5 group which will be created under base_group to |
550 store the new series. | |
551 | |
343 base_group : str | 552 base_group : str |
344 Path of the group under which to create the new group which will store the series. | 553 Path of the group under which to create the new group which will |
554 store the series. | |
555 | |
345 title : str | 556 title : str |
346 Here the title is attached to the new group, not a table. | 557 Here the title is attached to the new group, not a table. |
347 """ | 558 """ |
559 | |
560 # most other checks done when calling BasicStatisticsSeries | |
561 if type(new_group_name) != str: | |
562 raise TypeError("new_group_name must be a string") | |
563 if new_group_name == "": | |
564 raise ValueError("new_group_name must not be empty") | |
565 | |
348 base_series_list = [] | 566 base_series_list = [] |
349 | 567 |
350 new_group = hdf5_file.createGroup(base_group, new_group_name, title=title) | 568 new_group = hdf5_file.createGroup(base_group, new_group_name, title=title) |
351 | 569 |
352 stats_functions = {'mean': lambda(x): numpy.mean(x.value), | 570 stats_functions = {'mean': lambda(x): numpy.mean(x.value), |