Mercurial > pylearn
doc/v2_planning/arch_src/checkpoint_JB.py @ 1336:09ad2a4f663c
adding new idea to arch_src

author  | James Bergstra <bergstrj@iro.umontreal.ca>
date    | Mon, 18 Oct 2010 19:31:17 -0400
parents | 1335:7c51c0355d86
import copy as copy_module
import cPickle   # used by PickleCacheCtrl below
import os        # used by CtrlObj.open / open_unique

import numpy     # used by the numpy caching helpers and examples

#TODO: use logging.info to report cache hits / misses

CO_VARARGS = 0x0004
CO_VARKEYWORDS = 0x0008

class mem_db(dict):
    # A key->document dictionary.
    # A "document" is itself a dictionary.

    # A "document" can be a small or large object, but it cannot be
    # partially retrieved.

    # This simple data structure is used in pylearn to cache intermediate
    # results between several process invocations.
    pass

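# For illustration only (the key and field names here are hypothetical):
# a mem_db maps keys to dict "documents".
#
#   db = mem_db()
#   db['pca_run_1'] = dict(npy_filename='uniq_000001.npy')
#   db['pca_run_1']['npy_filename']   # -> 'uniq_000001.npy'
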
class UNSPECIFIED(object):
    pass

class CtrlObj(object):
    """
    Job control API.

    This interface makes it easier to break a logical program into pieces
    that can be executed by several different processes, either serially or
    in parallel.

    The base class provides decorators to simplify some common cache
    patterns:

    - cache_pickle to cache arbitrary return values using the pickle
      mechanism

    - cache_dict to cache dict return values directly using the document db

    - cache_numpy to cache [single] numpy ndarray rvals in a way that
      supports memmapping of large arrays.

    Authors are encouraged to use these when they apply, but should feel
    free to implement other cache logic, via the CtrlObj.get() and
    CtrlObj.set() methods, when these standard ones are lacking.
    """

    def __init__(self, rootdir, db, autosync):
        self.rootdir = rootdir  # the original never stored this, but open() and open_unique() use it
        self.db = db
        self.r_lookup = {}
        self.autosync = autosync

    def get(self, key, default_val=UNSPECIFIED, copy=True):
        # Default to returning a COPY, because a self.set() call is required
        # to make a change persistent.  In-place changes that the CtrlObj
        # does not know about (via self.set()) will not be saved.
        try:
            val = self.db[key]
        except KeyError:
            if default_val is not UNSPECIFIED:
                # Return default_val, but do not add it to the r_lookup
                # object, since looking up that key in the future would not
                # retrieve default_val.
                return default_val
            else:
                raise
        if copy:
            rval = copy_module.deepcopy(val)
        else:
            rval = val
        self.r_lookup[id(rval)] = key
        return rval

    def get_key(self, val):
        """Return the key that retrieved `val`.

        This is useful for specifying cache keys for unhashable (e.g. numpy)
        objects that happen to be stored in the db.
        """
        return self.r_lookup[id(val)]

    def set(self, key, val):
        vv = dict(val)
        if self.db.get(key, None) not in (val, None):
            del_keys = [k for (k, v) in self.r_lookup.iteritems() if v == key]
            for k in del_keys:
                del self.r_lookup[k]
        self.db[key] = vv

    def delete(self, key):
        del_keys = [k for (k, v) in self.r_lookup.iteritems() if v == key]
        for k in del_keys:
            del self.r_lookup[k]
        del self.db[key]

    def checkpoint(self):
        """Potentially pass control to another greenlet/tasklet, which could
        serialize this (calling) greenlet/tasklet using cPickle.
        """
        pass
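
    # A sketch of how checkpoint() might actually yield control, using the
    # third-party `greenlet` package.  GreenletCtrlObj and `scheduler` (the
    # parent greenlet) are hypothetical names, not part of this API:
    #
    #   import greenlet
    #
    #   class GreenletCtrlObj(CtrlObj):
    #       def __init__(self, scheduler, **kwargs):
    #           CtrlObj.__init__(self, **kwargs)
    #           self.scheduler = scheduler
    #       def checkpoint(self):
    #           # switch to the scheduler; execution resumes here when the
    #           # scheduler switches back to this greenlet
    #           self.scheduler.switch()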

    def sync(self, pull=True, push=True):
        """Synchronise local changes with a master version (if applicable).
        """
        pass

    def open(self, filename):
        """Return a file-handle to a file that can be synced with a server."""
        #TODO: save references / proxies of the file objects returned here
        #      and sync them with a server if they are closed
        return open(os.path.join(self.rootdir, filename))

    def open_unique(self, mode='wb', prefix='uniq_', suffix=''):
        #TODO: use the standard lib algo for this if you can find it.
        if suffix:
            template = prefix + '%06i.' + suffix
        else:
            template = prefix + '%06i'
        while True:
            fname = template % numpy.random.randint(999999)
            path = os.path.join(self.rootdir, fname)
            try:
                open(path).close()
            except IOError:  # file not found
                return open(path, mode=mode), path

def memory_ctrl_obj():
    # The original passed only db=dict(), but __init__ also requires
    # rootdir and autosync; '.' and False are assumed defaults.
    return CtrlObj(rootdir='.', db=dict(), autosync=False)

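# A minimal illustration of the get/set contract described in CtrlObj.get():
# values come back as copies, and in-place edits only persist once set() is
# called again.  The key 'doc' is hypothetical:
#
#   ctrl = memory_ctrl_obj()
#   ctrl.set('doc', dict(x=1))
#   d = ctrl.get('doc')           # a deep copy of the stored document
#   assert ctrl.get_key(d) == 'doc'
#   d['x'] = 2                    # not persistent yet...
#   ctrl.set('doc', d)            # ...until saved back explicitly
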
def directory_ctrl_obj(path, **kwargs):
    raise NotImplementedError()

def mongo_ctrl_obj(connection_args, **kwargs):
    raise NotImplementedError()

def couchdb_ctrl_obj(connection_args, **kwargs):
    raise NotImplementedError()

def jobman_ctrl_obj(connection_args, **kwargs):
    raise NotImplementedError()


def _default_values(f):
    """Return a dictionary param -> default value of function `f`'s parameters."""
    func_defaults = f.func_defaults
    if func_defaults:
        first_default_pos = f.func_code.co_argcount - len(func_defaults)
        params_with_defaults = f.func_code.co_varnames[
                first_default_pos:f.func_code.co_argcount]
        rval = dict(zip(params_with_defaults, func_defaults))
    else:
        rval = {}
    return rval

def test_default_values():

    def f(a): pass
    assert _default_values(f) == {}

    def f(a, b=1):
        aa = 5
    assert _default_values(f) == dict(b=1)

    def f(a, b=1, c=2, *args, **kwargs):
        e = b + c
        return e
    assert _default_values(f) == dict(b=1, c=2)

def _arg_assignment(f, args, kwargs):
    # Make a dictionary from args and kwargs that contains all the arguments
    # to f and their values.
    assignment = dict()

    params = f.func_code.co_varnames[:f.func_code.co_argcount]  #CORRECT?

    f_accepts_varargs = f.func_code.co_flags & CO_VARARGS
    f_accepts_kwargs = f.func_code.co_flags & CO_VARKEYWORDS

    if f_accepts_varargs:
        raise NotImplementedError()
    if f_accepts_kwargs:
        raise NotImplementedError()

    # First add positional arguments.
    #TODO: what if f accepts a '*args' or similar?
    assert len(args) <= f.func_code.co_argcount
    for i, a in enumerate(args):
        assignment[f.func_code.co_varnames[i]] = a  # CORRECT??

    # Next add kw arguments.
    for k, v in kwargs.iteritems():
        if k in assignment:
            #TODO: match Python error
            raise TypeError('duplicate argument provided for parameter', k)

        if (not f_accepts_kwargs) and (k not in params):
            #TODO: match Python error
            raise TypeError('invalid keyword argument', k)

        assignment[k] = v

    # Finally add default arguments for any remaining parameters.
    for k, v in _default_values(f).iteritems():
        if k in assignment:
            pass  # this argument has [already] been specified
        else:
            assignment[k] = v

    # TODO
    # check that the assignment covers all parameters without default values

    # TODO
    # check that the assignment includes no extra variables if f does not
    # accept a '**' parameter.

    return assignment

def test_arg_assignment():
    #TODO: check cases that should cause errors
    # - doubly-specified arguments
    # - insufficient arguments

    def f(): pass
    assert _arg_assignment(f, (), {}) == {}
    def f(a): pass
    assert _arg_assignment(f, (1,), {}) == {'a': 1}
    def f(a): pass
    assert _arg_assignment(f, (), {'a': 1}) == {'a': 1}

    def f(a=1): pass
    assert _arg_assignment(f, (), {}) == {'a': 1}
    def f(a=1): pass
    assert _arg_assignment(f, (2,), {}) == {'a': 2}
    def f(a=1): pass
    assert _arg_assignment(f, (), {'a': 2}) == {'a': 2}

    def f(b, a=1): pass
    assert _arg_assignment(f, (3, 4), {}) == {'b': 3, 'a': 4}
    def f(b, a=1): pass
    assert _arg_assignment(f, (3,), {'a': 4}) == {'b': 3, 'a': 4}
    def f(b, a=1): pass
    assert _arg_assignment(f, (), {'b': 3, 'a': 4}) == {'b': 3, 'a': 4}
    def f(b, a=1): pass
    assert _arg_assignment(f, (), {'b': 3}) == {'b': 3, 'a': 1}
    def f(b, a=1): a0 = 6
    assert _arg_assignment(f, (2,), {}) == {'b': 2, 'a': 1}

if 0:
    def test_arg_assignment_w_varargs():
        def f(b, c=1, *a, **kw): z = 5
        assert _arg_assignment(f, (3,), {}) == {
                'b': 3, 'c': 1, 'a': (), 'kw': {}}


class CtrlObjCacheWrapper(object):

    @classmethod
    def decorate(cls, *args, **kwargs):
        self = cls(*args, **kwargs)
        def rval(f):
            self.f = f
            # Return the wrapper instance so the decorated name stays
            # callable (the original returned None here, which would have
            # bound the decorated name to None).
            return self
        return rval
    def parse_args(self, args, kwargs):
        """Return (key, f_args, f_kwargs, ctrl_args), removing the
        ctrl-cache related flags from kwargs.

        The key is None or a hashable tuple that identifies all the
        arguments to the function.
        """
        ctrl_args = dict(
                ctrl=None,
                ctrl_ignore_cache=False,
                ctrl_force_shallow_recompute=False,
                ctrl_force_deep_recompute=False,
                )

        # Remove the ctrl and ctrl_* arguments, because they are not meant
        # to be passed to 'f'.
        ctrl_kwds = [(k, v) for (k, v) in kwargs.iteritems()
                if k.startswith('ctrl')]
        ctrl_args.update(dict(ctrl_kwds))
        f_kwds = [(k, v) for (k, v) in kwargs.iteritems()
                if not k.startswith('ctrl')]

        # assignment is a dictionary with a complete specification of the
        # effective arguments to f, including default values, varargs, and
        # varkwargs.
        assignment = _arg_assignment(self.f, args, dict(f_kwds))

        assignment_items = assignment.items()
        assignment_items.sort()  # canonical ordering for parameters

        # Replace argument values with explicitly provided keys.
        assignment_key = [(k, kwargs.get('ctrl_key_%s' % k, v))
                for (k, v) in assignment_items]

        rval_key = ('fn_cache', self.f, tuple(assignment_key))
        try:
            hash(rval_key)
        except TypeError:
            rval_key = None
        # Pass the full assignment in the keyword slot; the original put it
        # in the positional slot, where '*f_args' would have unpacked the
        # dict's *keys* as positional arguments.
        return rval_key, (), assignment, ctrl_args
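
    # For illustration (the function g and its arguments are hypothetical):
    # for a wrapped function g(a, b=2), the call
    #     g(1, ctrl=ctrl, ctrl_key_a='raw_data')
    # would parse to roughly
    #     key      = ('fn_cache', g, (('a', 'raw_data'), ('b', 2)))
    #     f_kwargs = {'a': 1, 'b': 2}
    # with ctrl_args['ctrl'] set and the remaining ctrl_* flags False.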

    @property
    def __doc__(self):
        # A property, so that instance.__doc__ yields the string below (as
        # the plain method of the original, it would never be consulted for
        # documentation).
        #TODO: Add documentation from self.f
        return """
        Optional magic kwargs:
        ctrl - use this handle for cache/checkpointing
        ctrl_key_%(paramname)s - specify a key to use for a cache lookup of this parameter
        ctrl_ignore_cache - completely ignore the cache (but checkpointing can still work)
        ctrl_force_shallow_recompute - refresh the cache (but not that of sub-calls)
        ctrl_force_deep_recompute - recursively refresh the cache
        ctrl_nocopy - skip the usual copy of a cached return value
        """

    def __call__(self, *args, **kwargs):
        # N.B.: ctrl_force_deep_recompute can work by inspecting the call
        # stack.  If any parent frame has a special variable set (e.g.
        # _child_frame_ctrl_force_deep_recompute), then it means this call
        # is a ctrl_force_deep_recompute too.
        key, f_args, f_kwargs, ctrl_args = self.parse_args(args, kwargs)

        ctrl = ctrl_args['ctrl']
        if ctrl is None or ctrl_args['ctrl_ignore_cache']:
            return self.f(*f_args, **f_kwargs)
        if key:
            try:
                return self.get_cached_val(ctrl, key)
            except KeyError:
                pass
        f_rval = self.f(*f_args, **f_kwargs)
        if key:
            f_rval = self.cache_val(ctrl, key, f_rval)
        return f_rval
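
    # A sketch of the stack-inspection idea from the N.B. above, using
    # sys._getframe to look for the special variable in parent frames
    # (_deep_recompute_requested is a hypothetical helper):
    #
    #   import sys
    #
    #   def _deep_recompute_requested():
    #       frame = sys._getframe(1)
    #       while frame is not None:
    #           if frame.f_locals.get('_child_frame_ctrl_force_deep_recompute'):
    #               return True
    #           frame = frame.f_back
    #       return False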

    def get_cached_val(self, ctrl, key):
        return ctrl.get(key)

    def cache_val(self, ctrl, key, val):
        ctrl.set(key, val)
        return val

class NumpyCacheCtrl(CtrlObjCacheWrapper):
    def get_cached_val(self, ctrl, key):
        filename = ctrl.get(key)['npy_filename']
        return numpy.load(filename)

    def cache_val(self, ctrl, key, val):
        try:
            # The original assigned the whole document to filename; the
            # 'npy_filename' field is what holds the path.
            filename = ctrl.get(key)['npy_filename']
        except KeyError:
            # The 'npy' suffix keeps the name that numpy.save() writes
            # consistent with the one stored in the db.
            handle, filename = ctrl.open_unique(suffix='npy')
            handle.close()
            ctrl.set(key, dict(npy_filename=filename))
        numpy.save(filename, val)
        return val
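
    # The CtrlObj docstring mentions memmapping of large arrays; one way to
    # support that here (assumed, not implemented above) would be to load
    # with a memory map:
    #   return numpy.load(filename, mmap_mode='r')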

class PickleCacheCtrl(CtrlObjCacheWrapper):
    def __init__(self, protocol=0, **kwargs):
        self.protocol = protocol
        super(PickleCacheCtrl, self).__init__(**kwargs)

    def get_cached_val(self, ctrl, key):
        return cPickle.loads(ctrl.get(key)['cPickle_str'])

    def cache_val(self, ctrl, key, val):
        # Use the configured protocol; the original dumps() call ignored
        # self.protocol.
        ctrl.set(key, dict(cPickle_str=cPickle.dumps(val, self.protocol)))
        return val

@NumpyCacheCtrl.decorate()
def get_raw_data(rows=10, cols=7, seed=67273):
    # The rows/cols defaults are assumed here (the original had none) so
    # that the ctrl-only call in test_run_experiment() is fully specified.
    return numpy.random.RandomState(seed).randn(rows, cols)

@NumpyCacheCtrl.decorate()
def get_whitened_dataset(X, pca, max_components=5):
    return X[:, :max_components]

@PickleCacheCtrl.decorate(protocol=-1)
def get_pca(X, max_components=100):
    return dict(
            mean=0,
            eigvals=numpy.ones(X.shape[1]),
            eigvecs=numpy.identity(X.shape[1]),
            )

@PickleCacheCtrl.decorate(protocol=-1)
def train_mean_var_model(data, ctrl):
    mean = numpy.zeros(data.shape[1])
    meansq = numpy.zeros(data.shape[1])
    for i in xrange(data.shape[0]):
        alpha = 1.0 / (i + 1)
        # Running averages use plain assignment; the original's '+=' would
        # have added each update on top of the accumulator.
        mean = (1 - alpha) * mean + data[i] * alpha
        meansq = (1 - alpha) * meansq + (data[i] ** 2) * alpha
        ctrl.checkpoint()
    return (mean, meansq)

def test_run_experiment():

    # Could use db, or filesystem, or both, etc.  There would be generic
    # ones, but the experimenter should be very aware of what is being
    # cached where, when, and how.  This is how results are stored and
    # retrieved, after all.  Cluster-friendly jobs should not use local
    # files directly, but should store cached computations and results in
    # such a database.  Different jobs should avoid using the same keys in
    # the database, because coordinating writes is difficult and conflicts
    # will inevitably arise.
    ctrl = memory_ctrl_obj()

    raw_data = get_raw_data(ctrl=ctrl)
    raw_data_key = ctrl.get_key(raw_data)

    pca = get_pca(raw_data, max_components=30, ctrl=ctrl)
    whitened_data = get_whitened_dataset(raw_data, pca, ctrl=ctrl)

    mean_var = train_mean_var_model(
            data=whitened_data + 66,
            ctrl=ctrl,
            ctrl_key_data=whitened_data)  # tell the cache that this temporary is tied to whitened_data

    mean, meansq = mean_var

    #TODO: Test that the cache actually worked!!