comparison doc/v2_planning/arch_src/checkpoint_JB.py @ 1336:09ad2a4f663c

adding new idea to arch_src
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 18 Oct 2010 19:31:17 -0400
parents
children
comparison
equal deleted inserted replaced
1335:7c51c0355d86 1336:09ad2a4f663c
import copy as copy_module
import cPickle
import os

import numpy
2
3 #TODO: use logging.info to report cache hits/ misses
4
# Bit masks for function.__code__.co_flags, mirroring CPython's CO_* constants.
CO_VARARGS = 0x0004       # set when the function accepts *args
CO_VARKEYWORDS = 0x0008   # set when the function accepts **kwargs
7
class mem_db(dict):
    """A key -> document dictionary.

    A "document" is itself a dictionary.

    A "document" can be a small or large object, but it cannot be partially
    retrieved.

    This simple data structure is used in pylearn to cache intermediate
    results between several process invocations.
    """
    pass
17
class UNSPECIFIED(object):
    # Sentinel used by CtrlObj.get to distinguish "no default supplied"
    # from an explicit default of None.
    pass
20
class CtrlObj(object):
    """
    Job control API.

    This interface makes it easier to break a logical program into pieces that
    can be executed by several different processes, either serially or in
    parallel.

    The base class provides decorators to simplify some common cache patterns:

    - cache_pickle to cache arbitrary return values using the pickle mechanism

    - cache_dict to cache dict return values directly using the document db

    - cache_numpy to cache [single] numpy ndarray rvals in a way that supports
      memmapping of large arrays.

    Authors are encouraged to use these when they apply, but should feel free
    to implement other cache logic when these standard ones are lacking, using
    the CtrlObj.get() and CtrlObj.set() methods.
    """

    def __init__(self, rootdir=None, db=None, autosync=False):
        """
        :param rootdir: directory for files created by open()/open_unique()
        :param db: key -> document mapping (dict-like, see mem_db)
        :param autosync: whether changes should be synced eagerly (currently
            just stored; sync() is a stub)
        """
        # Bug fix: rootdir was accepted but never stored, although open() and
        # open_unique() read self.rootdir.  Defaults added so that an
        # in-memory ctrl can be built with CtrlObj(db=dict()).
        self.rootdir = rootdir
        self.db = {} if db is None else db
        self.r_lookup = {}
        self.autosync = autosync

    def get(self, key, default_val=UNSPECIFIED, copy=True):
        """Return the document stored under `key`.

        Returns a COPY by default, because a self.set() is required to make a
        change persistent.  In-place changes that the CtrlObj does not know
        about (via self.set()) will not be saved.

        :raises KeyError: when `key` is absent and no `default_val` was given.
        """
        try:
            val = self.db[key]
        except KeyError:  # narrowed from a bare 'except:' that hid real errors
            if default_val is not UNSPECIFIED:
                # Return default_val, but do not add it to the r_lookup object
                # since looking up that key in the future would not retrieve
                # default_val.
                return default_val
            raise
        rval = copy_module.deepcopy(val) if copy else val
        # Remember which key produced this object so get_key() can recover it.
        self.r_lookup[id(rval)] = key
        return rval

    def get_key(self, val):
        """Return the key that retrieved `val`.

        This is useful for specifying cache keys for unhashable (e.g. numpy)
        objects that happen to be stored in the db.
        """
        return self.r_lookup[id(val)]

    def set(self, key, val):
        """Store document `val` (a dict) under `key`."""
        vv = dict(val)  # shallow copy so later mutation of val is not implicit
        if self.db.get(key, None) not in (val, None):
            # Overwriting with a different document: values previously fetched
            # under this key are stale, so drop their reverse-lookup entries.
            stale_ids = [i for (i, k) in self.r_lookup.items() if k == key]
            for i in stale_ids:
                del self.r_lookup[i]
        self.db[key] = vv

    def delete(self, key):
        """Remove `key` from the db and any reverse-lookup entries for it."""
        stale_ids = [i for (i, k) in self.r_lookup.items() if k == key]
        for i in stale_ids:
            del self.r_lookup[i]
        del self.db[key]

    def checkpoint(self):
        """Potentially pass control to another greenlet/tasklet that could
        potentially serialize this (calling) greenlet/tasklet using cPickle.
        """
        pass

    def sync(self, pull=True, push=True):
        """Synchronise local changes with a master version (if applicable)."""
        pass

    def open(self, filename):
        """Return a file-handle to a file (under self.rootdir) that can be
        synced with a server."""
        # TODO: save references / proxies of the file objects returned here
        #       and sync them with a server if they are closed.
        return open(os.path.join(self.rootdir, filename))

    def open_unique(self, mode='wb', prefix='uniq_', suffix=''):
        """Return (handle, path) for a freshly created, uniquely named file
        under self.rootdir.
        """
        # TODO: use the standard lib algo for this if you can find it
        #       (tempfile.mkstemp-style).
        if suffix:
            template = prefix + '%06i.' + suffix
        else:
            template = prefix + '%06i'
        while True:
            fname = template % numpy.random.randint(999999)
            path = os.path.join(self.rootdir, fname)
            try:
                open(path).close()
            except IOError:  # file not found -> this name is free
                return open(path, mode=mode), path
117
def memory_ctrl_obj():
    """Return a CtrlObj backed by an in-memory dict (no files, no server).

    Bug fix: CtrlObj's constructor takes (rootdir, db, autosync); calling it
    with only db raised TypeError.  Supply inert values for the others.
    """
    return CtrlObj(rootdir=None, db=dict(), autosync=False)
120
def directory_ctrl_obj(path, **kwargs):
    """Stub: CtrlObj backed by filesystem directory `path` (per the name).

    Not implemented.
    """
    raise NotImplementedError()

def mongo_ctrl_obj(connection_args, **kwargs):
    """Stub: CtrlObj backed by a MongoDB connection (per the name).

    Not implemented.
    """
    raise NotImplementedError()

def couchdb_ctrl_obj(connection_args, **kwargs):
    """Stub: CtrlObj backed by a CouchDB connection (per the name).

    Not implemented.
    """
    raise NotImplementedError()

def jobman_ctrl_obj(connection_args, **kwargs):
    """Stub: CtrlObj backed by a jobman database (per the name).

    Not implemented.
    """
    raise NotImplementedError()
132
133
134 def _default_values(f):
135 """Return a dictionary param -> default value of function `f`'s parameters"""
136 default_dict = {}
137 func_defaults = f.func_defaults
138 if func_defaults:
139 first_default_pos = f.func_code.co_argcount-len(f.func_defaults)
140 params_with_defaults = f.func_code.co_varnames[first_default_pos:f.func_code.co_argcount]
141 rval = dict(zip(params_with_defaults, f.func_defaults))
142 else:
143 rval = {}
144 return rval
145
def test_default_values():
    """Check _default_values on functions with zero, one, and two defaults."""
    def no_default(a):
        pass
    assert _default_values(no_default) == {}

    def one_default(a, b=1):
        aa = 5
    assert _default_values(one_default) == dict(b=1)

    def two_defaults(a, b=1, c=2, *args, **kwargs):
        e = b + c
        return e
    assert _default_values(two_defaults) == dict(b=1, c=2)
159
def _arg_assignment(f, args, kwargs):
    """Return a dict mapping every parameter of `f` to its effective value.

    Combines positional `args`, keyword `kwargs`, and `f`'s default values.

    :raises TypeError: on too many positional args, duplicate or invalid
        keyword args, or missing required args.
    :raises NotImplementedError: if `f` accepts *args or **kwargs.
    """
    assignment = dict()
    code = f.__code__  # was f.func_code (py2-only spelling)

    params = code.co_varnames[:code.co_argcount]

    f_accepts_varargs = code.co_flags & CO_VARARGS
    f_accepts_kwargs = code.co_flags & CO_VARKEYWORDS

    if f_accepts_varargs:
        raise NotImplementedError()
    if f_accepts_kwargs:
        raise NotImplementedError()

    # First add positional arguments.
    # Raise instead of assert: input validation must survive python -O.
    if len(args) > code.co_argcount:
        raise TypeError('too many positional arguments for function', f)
    for i, a in enumerate(args):
        assignment[params[i]] = a

    # Next add keyword arguments.
    for k, v in kwargs.items():
        if k in assignment:
            # TODO: match Python's exact error message
            raise TypeError('duplicate argument provided for parameter', k)
        if k not in params:
            # (f_accepts_kwargs is known false here — it raised above.)
            # TODO: match Python's exact error message
            raise TypeError('invalid keyword argument', k)
        assignment[k] = v

    # Finally add default values for any remaining parameters.
    for k, v in _default_values(f).items():
        assignment.setdefault(k, v)

    # Check that the assignment covers all parameters (was a TODO).
    missing = [p for p in params if p not in assignment]
    if missing:
        # TODO: match Python's exact error message
        raise TypeError('missing arguments for parameters', missing)

    return assignment
208
def test_arg_assignment():
    """Check _arg_assignment over positional/keyword/default combinations."""
    # TODO: also check cases that should raise:
    #  - doubly-specified arguments
    #  - insufficient arguments

    def f0():
        pass
    assert _arg_assignment(f0, (), {}) == {}

    def f1(a):
        pass
    assert _arg_assignment(f1, (1,), {}) == {'a': 1}
    assert _arg_assignment(f1, (), {'a': 1}) == {'a': 1}

    def f2(a=1):
        pass
    assert _arg_assignment(f2, (), {}) == {'a': 1}
    assert _arg_assignment(f2, (2,), {}) == {'a': 2}
    assert _arg_assignment(f2, (), {'a': 2}) == {'a': 2}
    assert _arg_assignment(f2, (), {'a': 2}) == {'a': 2}

    def f3(b, a=1):
        pass
    assert _arg_assignment(f3, (3, 4), {}) == {'b': 3, 'a': 4}
    assert _arg_assignment(f3, (3,), {'a': 4}) == {'b': 3, 'a': 4}
    assert _arg_assignment(f3, (), {'b': 3, 'a': 4}) == {'b': 3, 'a': 4}
    assert _arg_assignment(f3, (), {'b': 3}) == {'b': 3, 'a': 1}

    def f4(b, a=1):
        a0 = 6
    assert _arg_assignment(f4, (2,), {}) == {'b': 2, 'a': 1}
240
# Disabled: _arg_assignment raises NotImplementedError for functions that
# accept *args/**kwargs, so this test cannot pass yet.
if 0:
    def test_arg_assignment_w_varargs():
        def f(b, c=1, *a, **kw): z=5
        assert _arg_assignment(f, (3,), {}) == {'b':3, 'c':1, 'a':(), 'kw':{}}
245
246
# Callable wrapper that adds ctrl-based caching around a function.
# Subclasses override get_cached_val/cache_val to change the storage format.
# NOTE(review): __doc__ is deliberately left as a method (as in the original);
# it shadows any class docstring, which is why this header is a comment.
class CtrlObjCacheWrapper(object):

    @classmethod
    def decorate(cls, *args, **kwargs):
        """Return a decorator that wraps a function in a cls instance.

        Usage:
            @SomeCacheCtrl.decorate(...)
            def f(...): ...
        """
        self = cls(*args, **kwargs)
        def rval(f):
            self.f = f
            # Bug fix: return the wrapper so the decorated name stays
            # callable (previously the decorator returned None, clobbering
            # the decorated function).
            return self
        return rval

    def parse_args(self, args, kwargs):
        """Return (key, f_args, f_kwargs, ctrl_args) with ctrl-cache flags removed.

        The key is None or a hashable tuple that identifies all the arguments
        to the function.
        """
        ctrl_args = dict(
            ctrl=None,
            ctrl_ignore_cache=False,
            ctrl_force_shallow_recompute=False,
            ctrl_force_deep_recompute=False,
        )

        # Remove the ctrl and ctrl_* arguments because they are not meant to
        # be passed to 'f'.
        ctrl_kwds = [(k, v) for (k, v) in kwargs.items()
                     if k.startswith('ctrl')]
        ctrl_args.update(dict(ctrl_kwds))
        f_kwds = [(k, v) for (k, v) in kwargs.items()
                  if not k.startswith('ctrl')]

        # assignment is a dictionary with a complete specification of the
        # effective arguments to f, including default values.
        assignment = _arg_assignment(self.f, args, dict(f_kwds))

        # Canonical ordering for parameters (py3-safe: items() is a view,
        # so sort via sorted() rather than list.sort()).
        assignment_items = sorted(assignment.items())

        # Replace argument values with explicitly provided cache keys.
        assignment_key = [(k, kwargs.get('ctrl_key_%s' % k, v))
                          for (k, v) in assignment_items]

        rval_key = ('fn_cache', self.f, tuple(assignment_key))
        try:
            hash(rval_key)
        except TypeError:  # an argument value is unhashable -> no cache key
            rval_key = None
        # Bug fix: the assignment maps parameter names to values, so it must
        # be returned in the *keyword* slot.  It was previously returned as
        # f_args, and __call__'s f(*f_args) would have unpacked the parameter
        # names themselves as positional arguments.
        return rval_key, (), assignment, ctrl_args

    def __doc__(self):
        # TODO: Add documentation from self.f
        return """
        Optional magic kwargs:
        ctrl - use this handle for cache/checkpointing
        ctrl_key_%(paramname)s - specify a key to use for a cache lookup of this parameter
        ctrl_ignore_cache - completely ignore the cache (but checkpointing can still work)
        ctrl_force_shallow_recompute - refresh the cache (but not of sub-calls)
        ctrl_force_deep_recompute - recursively refresh the cache
        ctrl_nocopy - skip the usual copy of a cached return value
        """

    def __call__(self, *args, **kwargs):
        """Call self.f, consulting/updating the ctrl cache when possible."""
        # N.B. ctrl_force_deep_recompute can work by inspecting the call
        # stack: if any parent frame has a special variable set
        # (e.g. _child_frame_ctrl_force_deep_recompute) then it means this is
        # a ctrl_force_deep_recompute too.
        key, f_args, f_kwargs, ctrl_args = self.parse_args(args, kwargs)

        ctrl = ctrl_args['ctrl']
        if ctrl is None or ctrl_args['ctrl_ignore_cache']:
            return self.f(*f_args, **f_kwargs)
        if key:
            try:
                return self.get_cached_val(ctrl, key)
            except KeyError:
                pass
        f_rval = self.f(*f_args, **f_kwargs)
        if key:
            f_rval = self.cache_val(ctrl, key, f_rval)
        return f_rval

    def get_cached_val(self, ctrl, key):
        """Fetch a previously cached return value (raises KeyError on miss)."""
        return ctrl.get(key)

    def cache_val(self, ctrl, key, val):
        """Store `val` under `key` and return it (possibly re-loaded)."""
        ctrl.set(key, val)
        return val
330
class NumpyCacheCtrl(CtrlObjCacheWrapper):
    """Cache wrapper that stores a single ndarray rval in an .npy file.

    The document db stores only {'npy_filename': path}; the array itself
    lives on disk (which permits memmapping of large arrays).
    """
    def get_cached_val(self, ctrl, key):
        filename = ctrl.get(key)['npy_filename']
        return numpy.load(filename)

    def cache_val(self, ctrl, key, val):
        try:
            # Bug fix: extract the filename string from the cached document
            # (previously the whole dict was passed to numpy.save).
            filename = ctrl.get(key)['npy_filename']
        except KeyError:
            # Bug fix: the CtrlObj method is open_unique, not open_uniq.
            handle, filename = ctrl.open_unique()
            handle.close()
            ctrl.set(key, dict(npy_filename=filename))
        numpy.save(filename, val)
        return val
344
class PickleCacheCtrl(CtrlObjCacheWrapper):
    """Cache wrapper that pickles arbitrary rvals into the document db."""
    def __init__(self, protocol=0, **kwargs):
        # protocol: pickle protocol number; -1 means highest available.
        self.protocol = protocol
        super(PickleCacheCtrl, self).__init__(**kwargs)

    def get_cached_val(self, ctrl, key):
        return cPickle.loads(ctrl.get(key)['cPickle_str'])

    def cache_val(self, ctrl, key, val):
        # Bug fix: honor the configured protocol (it was stored in __init__
        # but never passed to dumps).
        ctrl.set(key, dict(cPickle_str=cPickle.dumps(val, self.protocol)))
        return val
354
@NumpyCacheCtrl.decorate()
def get_raw_data(rows, cols, seed=67273):
    """Return a (rows, cols) array of standard-normal samples from `seed`."""
    rng = numpy.random.RandomState(seed)
    return rng.randn(rows, cols)
358
@NumpyCacheCtrl.decorate()
def get_whitened_dataset(X, pca, max_components=5):
    """Return the first `max_components` columns of X.

    NOTE(review): `pca` is currently unused here — looks like a placeholder
    for real whitening.
    """
    keep = slice(None, max_components)
    return X[:, keep]
362
@PickleCacheCtrl.decorate(protocol=-1)
def get_pca(X, max_components=100):
    """Return a stub PCA model for X: zero mean, unit eigvals, identity basis.

    NOTE(review): `max_components` is currently unused by this stub.
    """
    n_features = X.shape[1]
    return {
        'mean': 0,
        'eigvals': numpy.ones(n_features),
        'eigvecs': numpy.identity(n_features),
    }
370
@PickleCacheCtrl.decorate(protocol=-1)
def train_mean_var_model(data, ctrl):
    """Return (mean, meansq): per-column running mean and mean-of-squares.

    Processes one row at a time and calls ctrl.checkpoint() after each row so
    the computation can be suspended/resumed between rows.
    """
    mean = numpy.zeros(data.shape[1])
    meansq = numpy.zeros(data.shape[1])
    for i in range(data.shape[0]):
        # alpha = 1/(i+1) makes this an exact incremental average.
        alpha = 1.0 / (i + 1)
        # Bug fix: these updates used '+=', which added the blended value to
        # the running average instead of replacing it.
        mean = (1 - alpha) * mean + data[i] * alpha
        meansq = (1 - alpha) * meansq + (data[i] ** 2) * alpha
        ctrl.checkpoint()
    return (mean, meansq)
381
def test_run_experiment():
    """Smoke-test the caching decorators end-to-end with an in-memory db."""
    # A ctrl could be backed by a db, the filesystem, or both.  Generic ones
    # exist, but the experimenter should stay very aware of what is cached
    # where, when, and how — this is how results are stored and retrieved.
    # Cluster-friendly jobs should not use local files directly; they should
    # store cached computations and results in such a database.  Distinct
    # jobs should avoid sharing keys, since coordinating writes is difficult
    # and conflicts would inevitably arise.
    ctrl = memory_ctrl_obj()

    raw = get_raw_data(ctrl=ctrl)
    raw_key = ctrl.get_key(raw)

    pca_model = get_pca(raw, max_components=30, ctrl=ctrl)
    whitened = get_whitened_dataset(raw, pca_model, ctrl=ctrl)

    # ctrl_key_data ties the temporary (whitened + 66) to whitened's key.
    stats = train_mean_var_model(
        data=whitened + 66,
        ctrl=ctrl,
        ctrl_key_data=whitened)

    mean, var = stats

    # TODO: Test that the cache actually worked!!
407
408