# doc/v2_planning/arch_src/checkpoint_JB.py  (pylearn changeset 1336:09ad2a4f663c)
# "adding new idea to arch_src"
# author: James Bergstra <bergstrj@iro.umontreal.ca>
# date:   Mon, 18 Oct 2010 19:31:17 -0400
import copy as copy_module
import os
import tempfile

import cPickle
import numpy

#TODO: use logging.info to report cache hits/misses

CO_VARARGS = 0x0004
CO_VARKEYWORDS = 0x0008

class mem_db(dict):
    # A key->document dictionary.
    # A "document" is itself a dictionary.
    # A "document" can be a small or large object, but it cannot be partially retrieved.
    # This simple data structure is used in pylearn to cache intermediate results between
    # several process invocations.
    pass

class UNSPECIFIED(object):
    pass

class CtrlObj(object):
    """
    Job control API.

    This interface makes it easier to break a logical program into pieces that can be
    executed by several different processes, either serially or in parallel.

    The base class provides decorators to simplify some common cache patterns:
    - cache_pickle to cache arbitrary return values using the pickle mechanism
    - cache_dict to cache dict return values directly using the document db
    - cache_numpy to cache [single] numpy ndarray rvals in a way that supports
      memmapping of large arrays.

    Authors are encouraged to use these when they apply, but should feel free to
    implement other cache logic when these standard ones are lacking, using the
    CtrlObj.get() and CtrlObj.set() methods.
    """

    def __init__(self, rootdir, db, autosync):
        self.rootdir = rootdir
        self.db = db
        self.r_lookup = {}
        self.autosync = autosync

    def get(self, key, default_val=UNSPECIFIED, copy=True):
        # Default to returning a COPY because a self.set() is required to make a change
        # persistent.  In-place changes that the CtrlObj does not know about (via
        # self.set()) will not be saved.
        try:
            val = self.db[key]
        except KeyError:
            if default_val is not UNSPECIFIED:
                # Return default_val, but do not add it to the r_lookup object,
                # since looking up that key in the future would not retrieve default_val.
                return default_val
            else:
                raise
        if copy:
            rval = copy_module.deepcopy(val)
        else:
            rval = val
        self.r_lookup[id(rval)] = key
        return rval

    def get_key(self, val):
        """Return the key that retrieved `val`.

        This is useful for specifying cache keys for unhashable (e.g. numpy) objects
        that happen to be stored in the db.
        """
        return self.r_lookup[id(val)]

    def set(self, key, val):
        vv = dict(val)
        if self.db.get(key, None) not in (val, None):
            # the key is being re-bound to a new document, so stale reverse-lookup
            # entries for it must be discarded
            del_keys = [k for (k, v) in self.r_lookup.iteritems() if v == key]
            for k in del_keys:
                del self.r_lookup[k]
        self.db[key] = vv

    def delete(self, key):
        del_keys = [k for (k, v) in self.r_lookup.iteritems() if v == key]
        for k in del_keys:
            del self.r_lookup[k]
        del self.db[key]

    def checkpoint(self):
        """Potentially pass control to another greenlet/tasklet that could potentially
        serialize this (calling) greenlet/tasklet using cPickle.
        """
        pass

    def sync(self, pull=True, push=True):
        """Synchronise local changes with a master version (if applicable).
        """
        pass

    def open(self, filename):
        """Return a file-handle to a file that can be synced with a server"""
        #TODO: save references / proxies of the file objects returned here
        #      and sync them with a server if they are closed
        return open(os.path.join(self.rootdir, filename))

    def open_unique(self, mode='wb', prefix='uniq_', suffix=''):
        #TODO: use the standard lib algo for this if you can find it.
        if suffix:
            template = prefix + '%06i.' + suffix
        else:
            template = prefix + '%06i'
        while True:
            fname = template % numpy.random.randint(999999)
            path = os.path.join(self.rootdir, fname)
            try:
                open(path).close()
            except IOError:  # file not found: the name is free
                return open(path, mode=mode), path
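# A minimal usage sketch of the get/set/get_key round trip above, using an
# in-memory mem_db.  The key names below are illustrative only; this demo is
# not part of the API.
def _demo_ctrl_obj_roundtrip():
    ctrl = CtrlObj(rootdir=None, db=mem_db(), autosync=False)
    ctrl.set('raw', dict(rows=10, cols=5))
    doc = ctrl.get('raw')                # returns a deep copy by default
    assert ctrl.get_key(doc) == 'raw'    # reverse lookup via id(doc)
    ctrl.delete('raw')
    assert ctrl.get('raw', default_val=None) is None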
def memory_ctrl_obj():
    # File-backed caches (e.g. NumpyCacheCtrl below) still need a real directory,
    # so the in-memory ctrl uses the system temp dir as its rootdir.
    return CtrlObj(rootdir=tempfile.gettempdir(), db=mem_db(), autosync=False)

def directory_ctrl_obj(path, **kwargs):
    raise NotImplementedError()

def mongo_ctrl_obj(connection_args, **kwargs):
    raise NotImplementedError()

def couchdb_ctrl_obj(connection_args, **kwargs):
    raise NotImplementedError()

def jobman_ctrl_obj(connection_args, **kwargs):
    raise NotImplementedError()

def _default_values(f):
    """Return a dictionary param -> default value of function `f`'s parameters"""
    func_defaults = f.func_defaults
    if func_defaults:
        first_default_pos = f.func_code.co_argcount - len(func_defaults)
        params_with_defaults = f.func_code.co_varnames[first_default_pos:f.func_code.co_argcount]
        rval = dict(zip(params_with_defaults, func_defaults))
    else:
        rval = {}
    return rval

def test_default_values():
    def f(a):
        pass
    assert _default_values(f) == {}

    def f(a, b=1):
        aa = 5
    assert _default_values(f) == dict(b=1)

    def f(a, b=1, c=2, *args, **kwargs):
        e = b + c
        return e
    assert _default_values(f) == dict(b=1, c=2)
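# Illustrative sanity check (not part of the API) of the co_flags introspection
# that _arg_assignment below relies on: CO_VARARGS / CO_VARKEYWORDS are set on a
# function's code object when it declares *args / **kwargs.
def _demo_co_flags():
    def f(a, *args, **kwargs):
        pass
    assert f.func_code.co_flags & CO_VARARGS
    assert f.func_code.co_flags & CO_VARKEYWORDS

    def g(a):
        pass
    assert not (g.func_code.co_flags & CO_VARARGS)
    assert not (g.func_code.co_flags & CO_VARKEYWORDS)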
def _arg_assignment(f, args, kwargs):
    # Make a dictionary from args and kwargs that contains all the arguments to f
    # and their values.
    assignment = dict()
    # co_varnames begins with the names of f's formal parameters
    params = f.func_code.co_varnames[:f.func_code.co_argcount]
    f_accepts_varargs = f.func_code.co_flags & CO_VARARGS
    f_accepts_kwargs = f.func_code.co_flags & CO_VARKEYWORDS
    if f_accepts_varargs:
        raise NotImplementedError()
    if f_accepts_kwargs:
        raise NotImplementedError()

    # first add positional arguments
    assert len(args) <= f.func_code.co_argcount
    for i, a in enumerate(args):
        assignment[params[i]] = a

    # next add kw arguments
    for k, v in kwargs.iteritems():
        if k in assignment:
            #TODO: match Python's error message
            raise TypeError('duplicate argument provided for parameter', k)
        if (not f_accepts_kwargs) and (k not in params):
            #TODO: match Python's error message
            raise TypeError('invalid keyword argument', k)
        assignment[k] = v

    # finally add default arguments for any remaining parameters
    for k, v in _default_values(f).iteritems():
        if k in assignment:
            pass  # this argument has already been specified
        else:
            assignment[k] = v

    #TODO: check that the assignment covers all parameters without default values
    #TODO: check that the assignment includes no extra variables if f does not
    #      accept a '**' parameter
    return assignment

def test_arg_assignment():
    #TODO: check cases that should cause errors
    #  - doubly-specified arguments,
    #  - insufficient arguments

    def f(): pass
    assert _arg_assignment(f, (), {}) == {}
    def f(a): pass
    assert _arg_assignment(f, (1,), {}) == {'a': 1}
    def f(a): pass
    assert _arg_assignment(f, (), {'a': 1}) == {'a': 1}
    def f(a=1): pass
    assert _arg_assignment(f, (), {}) == {'a': 1}
    def f(a=1): pass
    assert _arg_assignment(f, (2,), {}) == {'a': 2}
    def f(a=1): pass
    assert _arg_assignment(f, (), {'a': 2}) == {'a': 2}
    def f(b, a=1): pass
    assert _arg_assignment(f, (3, 4), {}) == {'b': 3, 'a': 4}
    def f(b, a=1): pass
    assert _arg_assignment(f, (3,), {'a': 4}) == {'b': 3, 'a': 4}
    def f(b, a=1): pass
    assert _arg_assignment(f, (), {'b': 3, 'a': 4}) == {'b': 3, 'a': 4}
    def f(b, a=1): pass
    assert _arg_assignment(f, (), {'b': 3}) == {'b': 3, 'a': 1}
    def f(b, a=1): a0 = 6
    assert _arg_assignment(f, (2,), {}) == {'b': 2, 'a': 1}

if 0:
    def test_arg_assignment_w_varargs():
        def f(b, c=1, *a, **kw):
            z = 5
        assert _arg_assignment(f, (3,), {}) == {'b': 3, 'c': 1, 'a': (), 'kw': {}}

class CtrlObjCacheWrapper(object):
    """Function decorator that caches return values through a CtrlObj.

    Optional magic kwargs accepted by the wrapped function:
        ctrl - use this handle for cache/checkpointing
        ctrl_key_%(paramname)s - specify a key to use for a cache lookup of this parameter
        ctrl_ignore_cache - completely ignore the cache (but checkpointing can still work)
        ctrl_force_shallow_recompute - refresh the cache (but not of sub-calls)
        ctrl_force_deep_recompute - recursively refresh the cache
        ctrl_nocopy - skip the usual copy of a cached return value
    """
    #TODO: merge the wrapped function's docstring into the wrapper's

    @classmethod
    def decorate(cls, *args, **kwargs):
        self = cls(*args, **kwargs)
        def rval(f):
            self.f = f
            return self
        return rval

    def parse_args(self, args, kwargs):
        """Return key, f_args, f_kwargs, ctrl_args, by removing ctrl-cache related flags.

        The key is None or a hashable tuple that identifies all the arguments to the
        function.
        """
        ctrl_args = dict(
                ctrl=None,
                ctrl_ignore_cache=False,
                ctrl_force_shallow_recompute=False,
                ctrl_force_deep_recompute=False,
                )
        # remove the ctrl and ctrl_* arguments
        # because they are not meant to be passed to 'f'
        ctrl_kwds = [(k, v) for (k, v) in kwargs.iteritems()
                if k.startswith('ctrl')]
        ctrl_args.update(dict(ctrl_kwds))
        f_kwds = [(k, v) for (k, v) in kwargs.iteritems()
                if not k.startswith('ctrl')]

        # assignment is a dictionary with a complete specification of the effective
        # arguments to f, including default values
        assignment = _arg_assignment(self.f, args, dict(f_kwds))
        assignment_items = assignment.items()
        assignment_items.sort()  # canonical ordering for parameters

        # replace argument values with explicitly provided keys
        assignment_key = [(k, kwargs.get('ctrl_key_%s' % k, v))
                for (k, v) in assignment_items]

        rval_key = ('fn_cache', self.f, tuple(assignment_key))
        try:
            hash(rval_key)
        except TypeError:  # some argument (e.g. an ndarray) is unhashable
            rval_key = None
        # the full assignment is passed to f as keyword arguments
        return rval_key, (), assignment, ctrl_args

    def __call__(self, *args, **kwargs):
        # N.B. ctrl_force_deep_recompute can work by inspecting the call stack:
        # if any parent frame has a special variable set
        # (e.g. _child_frame_ctrl_force_deep_recompute)
        # then it means this is a ctrl_force_deep_recompute too.
        key, f_args, f_kwargs, ctrl_args = self.parse_args(args, kwargs)
        ctrl = ctrl_args['ctrl']
        if ctrl is not None and \
                'ctrl' in self.f.func_code.co_varnames[:self.f.func_code.co_argcount]:
            # pass the ctrl handle through to functions that declare a `ctrl`
            # parameter (e.g. for checkpointing)
            f_kwargs['ctrl'] = ctrl
        if ctrl is None or ctrl_args['ctrl_ignore_cache']:
            return self.f(*f_args, **f_kwargs)
        if key:
            try:
                rval = self.get_cached_val(ctrl, key)
                ctrl.r_lookup[id(rval)] = key  # so ctrl.get_key() works on cache hits
                return rval
            except KeyError:
                pass
        f_rval = self.f(*f_args, **f_kwargs)
        if key:
            f_rval = self.cache_val(ctrl, key, f_rval)
            # record the reverse mapping so that ctrl.get_key() also works on
            # freshly-cached return values
            ctrl.r_lookup[id(f_rval)] = key
        return f_rval

    def get_cached_val(self, ctrl, key):
        return ctrl.get(key)

    def cache_val(self, ctrl, key, val):
        ctrl.set(key, val)
        return val
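# A minimal sketch of the decorator protocol above: the base wrapper caches the
# return value directly in the document db (so the return value must be a dict,
# as required by CtrlObj.set).  `make_config` is illustrative only.
@CtrlObjCacheWrapper.decorate()
def make_config(n_hidden=100):
    return dict(n_hidden=n_hidden)

def _demo_wrapper_cache():
    ctrl = memory_ctrl_obj()
    cfg0 = make_config(ctrl=ctrl)   # computed, then cached
    cfg1 = make_config(ctrl=ctrl)   # served from the cache
    assert cfg0 == cfg1 == dict(n_hidden=100)
    assert make_config(n_hidden=2, ctrl=ctrl) == dict(n_hidden=2)  # different key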
class NumpyCacheCtrl(CtrlObjCacheWrapper):
    def get_cached_val(self, ctrl, key):
        filename = ctrl.get(key)['npy_filename']
        return numpy.load(filename)

    def cache_val(self, ctrl, key, val):
        try:
            filename = ctrl.get(key)['npy_filename']
        except KeyError:
            # use a '.npy' suffix so that numpy.save does not append one,
            # which would make the saved file unfindable by numpy.load
            handle, filename = ctrl.open_unique(suffix='npy')
            handle.close()
            ctrl.set(key, dict(npy_filename=filename))
        numpy.save(filename, val)
        return val

class PickleCacheCtrl(CtrlObjCacheWrapper):
    def __init__(self, protocol=0, **kwargs):
        self.protocol = protocol
        super(PickleCacheCtrl, self).__init__(**kwargs)

    def get_cached_val(self, ctrl, key):
        return cPickle.loads(ctrl.get(key)['cPickle_str'])

    def cache_val(self, ctrl, key, val):
        ctrl.set(key, dict(cPickle_str=cPickle.dumps(val, self.protocol)))
        return val

@NumpyCacheCtrl.decorate()
def get_raw_data(rows, cols, seed=67273):
    return numpy.random.RandomState(seed).randn(rows, cols)

@NumpyCacheCtrl.decorate()
def get_whitened_dataset(X, pca, max_components=5):
    return X[:, :max_components]

@PickleCacheCtrl.decorate(protocol=-1)
def get_pca(X, max_components=100):
    return dict(
            mean=0,
            eigvals=numpy.ones(X.shape[1]),
            eigvecs=numpy.identity(X.shape[1]))

@PickleCacheCtrl.decorate(protocol=-1)
def train_mean_var_model(data, ctrl):
    mean = numpy.zeros(data.shape[1])
    meansq = numpy.zeros(data.shape[1])
    for i in xrange(data.shape[0]):
        # running averages: assign rather than `+=`, which would double-count
        # the old estimate
        alpha = 1.0 / (i + 1)
        mean = (1 - alpha) * mean + data[i] * alpha
        meansq = (1 - alpha) * meansq + (data[i] ** 2) * alpha
        ctrl.checkpoint()
    return (mean, meansq)

def test_run_experiment():
    # Could use db, or filesystem, or both, etc.
    # There would be generic ones, but the experimenter should be very aware of what
    # is being cached where, when, and how.  This is how results are stored and
    # retrieved, after all.
    # Cluster-friendly jobs should not use local files directly, but should store
    # cached computations and results to such a database.
    # Different jobs should avoid using the same keys in the database because
    # coordinating writes is difficult, and conflicts will inevitably arise.
    ctrl = memory_ctrl_obj()

    raw_data = get_raw_data(rows=100, cols=10, ctrl=ctrl)  # sizes chosen for the test
    raw_data_key = ctrl.get_key(raw_data)

    # raw_data itself is unhashable, so pass its db key (via ctrl_key_X) to make
    # the arguments of get_pca cacheable
    pca = get_pca(raw_data, max_components=30, ctrl=ctrl, ctrl_key_X=raw_data_key)
    whitened_data = get_whitened_dataset(raw_data, pca, ctrl=ctrl)

    mean_var = train_mean_var_model(
            data=whitened_data + 66,
            ctrl=ctrl,
            ctrl_key_data=whitened_data)  # tell that the temporary is tied to whitened_data
    mean, var = mean_var
    #TODO: Test that the cache actually worked!!  (see the sketch below)
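# A rough sketch (illustrative, not part of the original test) of how the cache
# could be verified: calling a cached function twice with the same arguments
# should leave the document db unchanged and return equal values.
# `_cached_square` is a hypothetical helper defined just for this check.
@PickleCacheCtrl.decorate(protocol=-1)
def _cached_square(x=2):
    return dict(y=x * x)

def test_cache_hit():
    ctrl = memory_ctrl_obj()
    r0 = _cached_square(x=3, ctrl=ctrl)
    n_docs = len(ctrl.db)
    r1 = _cached_square(x=3, ctrl=ctrl)   # same key: served from the cache
    assert len(ctrl.db) == n_docs         # no new document was written
    assert r0 == r1 == dict(y=9)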