# HG changeset patch # User James Bergstra # Date 1287675756 14400 # Node ID 9e898b2b98e0d06a2e9d6e7d14d00b7d30e8310c # Parent 5e893cd6daae3441ce1d883c4ffbbfdde52a2e2e removed checkpoint draft diff -r 5e893cd6daae -r 9e898b2b98e0 doc/v2_planning/arch_src/checkpoint_JB.py --- a/doc/v2_planning/arch_src/checkpoint_JB.py Wed Oct 27 10:37:18 2010 -0400 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,408 +0,0 @@ -import copy as copy_module - -#TODO: use logging.info to report cache hits/ misses - -CO_VARARGS = 0x0004 -CO_VARKEYWORDS = 0x0008 - -class mem_db(dict): - # A key->document dictionary. - # A "document" is itself a dictionary. - - # A "document" can be a small or large object, but it cannot be partially retrieved. - - # This simple data structure is used in pylearn to cache intermediate reults between - # several process invocations. - pass - -class UNSPECIFIED(object): - pass - -class CtrlObj(object): - """ - Job control API. - - This interface makes it easier to break a logical program into pieces that can be - executed by several different processes, either serially or in parallel. - - - The base class provides decorators to simplify some common cache patterns: - - cache_pickle to cache arbitrary return values using the pickle mechanism - - - cache_dict to cache dict return values directly using the document db - - - cache_numpy to cache [single] numpy ndarray rvals in a way that supports memmapping of - large arrays. - - Authors are encouraged to use these when they apply, but should feel free to implement - other cache logic when these standard ones are lacking using the CtorlObj.get() and - CtorlObj.set() methods. - - - """ - - def __init__(self, rootdir, db, autosync): - self.db = db - self.r_lookup = {} - self.autosync=autosync - - def get(self, key, default_val=UNSPECIFIED, copy=True): - # Default to return a COPY because a self.set() is required to make a change persistent. - # Inplace changes that the CtrlObj does not know about (via self.set()) will not be saved. - try: - val = self.db[key] - except: - if default_val is not UNSPECIFIED: - # return default_val, but do not add it to the r_lookup object - # since looking up that key in the future would not retrieve default_val - return default_val - else: - raise - if copy: - rval = copy_module.deepcopy(val) - else: - rval = val - self.r_lookup[id(rval)] = key - return rval - - def get_key(self, val): - """Return the key that retrieved `val`. - - This is useful for specifying cache keys for unhashable (e.g. numpy) objects that - happen to be stored in the db. - """ - return self.r_lookup[id(val)] - def set(self, key, val): - vv = dict(val) - if self.db.get(key, None) not in (val, None): - del_keys = [k for (k,v) in self.r_lookup.iteritems() if v == key] - for k in del_keys: - del self.r_lookup[k] - self.db[key] = vv - def delete(self, key): - del_keys = [k for (k,v) in self.r_lookup.iteritems() if v == key] - for k in del_keys: - del self.r_lookup[k] - del self.db[key] - def checkpoint(self): - """Potentially pass control to another greenlet/tasklet that could potentially - serialize this (calling) greenlet/tasklet using cPickle. - """ - pass - - def sync(self, pull=True, push=True): - """Synchronise local changes with a master version (if applicable). - """ - pass - - def open(self, filename): - """Return a file-handle to a file that can be synced with a server""" - #todo - save references / proxies of the file objects returned here - # and sync them with a server if they are closed - return open(os.path.join(self.rootdir, filename)) - - def open_unique(self, mode='wb', prefix='uniq_', suffix=''): - #TODO: use the standard lib algo for this if you can find it. - if suffix: - template = prefix+'%06i.'+suffix - else: - template = prefix+'%06i' - while True: - fname = template%numpy.random.randint(999999) - path = os.path.join(self.rootdir, fname) - try: - open(path).close() - except IOError: #file not found - return open(path, mode=mode), path - -def memory_ctrl_obj(): - return CtrlObj(db=dict()) - -def directory_ctrl_obj(path, **kwargs): - raise NotImplementedError() - -def mongo_ctrl_obj(connection_args, **kwargs): - raise NotImplementedError() - -def couchdb_ctrl_obj(connection_args, **kwargs): - raise NotImplementedError() - -def jobman_ctrl_obj(connection_args, **kwargs): - raise NotImplementedError() - - -def _default_values(f): - """Return a dictionary param -> default value of function `f`'s parameters""" - default_dict = {} - func_defaults = f.func_defaults - if func_defaults: - first_default_pos = f.func_code.co_argcount-len(f.func_defaults) - params_with_defaults = f.func_code.co_varnames[first_default_pos:f.func_code.co_argcount] - rval = dict(zip(params_with_defaults, f.func_defaults)) - else: - rval = {} - return rval - -def test_default_values(): - - def f(a): pass - assert _default_values(f) == {} - - def f(a, b=1): - aa = 5 - assert _default_values(f) == dict(b=1) - - def f(a, b=1, c=2, *args, **kwargs): - e = b+c - return e - assert _default_values(f) == dict(b=1, c=2) - -def _arg_assignment(f, args, kwargs): - # make a dictionary from args and kwargs that contains all the arguments to f and their - # values - assignment = dict() - - params = f.func_code.co_varnames[:f.func_code.co_argcount] #CORRECT? - - f_accepts_varargs = f.func_code.co_flags & CO_VARARGS - f_accepts_kwargs = f.func_code.co_flags & CO_VARKEYWORDS - - if f_accepts_varargs: - raise NotImplementedError() - if f_accepts_kwargs: - raise NotImplementedError() - - # first add positional arguments - #TODO: what if f accepts a '*args' or similar? - assert len(args) <= f.func_code.co_argcount - for i, a in enumerate(args): - assignment[f.func_code.co_varnames[i]] = a # CORRECT?? - - # next add kw arguments - for k,v in kwargs.iteritems(): - if k in assignment: - #TODO: match Python error - raise TypeError('duplicate argument provided for parameter', k) - - if (not f_accepts_kwargs) and (k not in params): - #TODO: match Python error - raise TypeError('invalid keyword argument', k) - - assignment[k] = v - - # finally add default arguments for any remaining parameters - for k,v in _default_values(f).iteritems(): - if k in assignment: - pass # this argument has [already] been specified - else: - assignment[k] = v - - # TODO - # check that the assignment covers all parameters without default values - - # TODO - # check that the assignment includes no extra variables if f does not accept a '**' - # parameter. - - return assignment - -def test_arg_assignment(): - #TODO: check cases that should cause errors - # - doubly-specified arguments, - # - insufficient arguments - - def f():pass - assert _arg_assignment(f, (), {}) == {} - def f(a):pass - assert _arg_assignment(f, (1,), {}) == {'a':1} - def f(a):pass - assert _arg_assignment(f, (), {'a':1}) == {'a':1} - - def f(a=1):pass - assert _arg_assignment(f, (), {}) == {'a':1} - def f(a=1):pass - assert _arg_assignment(f, (2,), {}) == {'a':2} - def f(a=1):pass - assert _arg_assignment(f, (), {'a':2}) == {'a':2} - def f(a=1):pass - assert _arg_assignment(f, (), {'a':2}) == {'a':2} - - def f(b, a=1): pass - assert _arg_assignment(f, (3,4), {}) == {'b':3, 'a':4} - def f(b, a=1): pass - assert _arg_assignment(f, (3,), {'a':4}) == {'b':3, 'a':4} - def f(b, a=1): pass - assert _arg_assignment(f, (), {'b':3,'a':4}) == {'b':3, 'a':4} - def f(b, a=1): pass - assert _arg_assignment(f, (), {'b':3}) == {'b':3, 'a':1} - def f(b, a=1): a0=6 - assert _arg_assignment(f, (2,), {}) == {'b':2, 'a':1} - -if 0: - def test_arg_assignment_w_varargs(): - def f(b, c=1, *a, **kw): z=5 - assert _arg_assignment(f, (3,), {}) == {'b':3, 'c':1, 'a':(), 'kw':{}} - - -class CtrlObjCacheWrapper(object): - - @classmethod - def decorate(cls, *args, **kwargs): - self = cls(*args, **kwargs) - def rval(f): - self.f = f - return rval - def parse_args(self, args, kwargs): - """Return key, f_args, f_kwargs, by removing ctrl-cache related flags. - - The key is None or a hashable pair that identifies all the arguments to the function. - """ - ctrl_args = dict( - ctrl = None, - ctrl_ignore_cache=False, - ctrl_force_shallow_recompute=False, - ctrl_force_deep_recompute=False, - ) - - # remove the ctrl and ctrl_* arguments - # because they are not meant to be passed to 'f' - ctrl_kwds = [(k,v) for (k,v) in kwargs.iteritems() - if k.startswith('ctrl')] - ctrl_args.update(dict(ctrl_kwds)) - f_kwds = [(k,v) for (k,v) in kwargs.iteritems() - if not k.startswith('ctrl')] - - # assignment is a dictionary with a complete specification of the effective arguments to f - # including default values, varargs, and varkwargs. - assignment = _arg_assignment(self.f, args, dict(f_kwds)) - - assignment_items = assignment.items() - assignment_items.sort() #canonical ordering for parameters - - # replace argument values with explicitly provided keys - assignment_key = [(k, kwargs.get('ctrl_key_%s'%k, v)) - for (k,v) in assignment_items] - - rval_key = ('fn_cache', self.f, tuple(assignment_key)) - try: - hash(rval_key) - except: - rval_key = None - return rval_key, assignment, {}, ctrl_args - - def __doc__(self): - #TODO: Add documentation from self.f - return """ - Optional magic kwargs: - ctrl - use this handle for cache/checkpointing - ctrl_key_%(paramname)s - specify a key to use for a cache lookup of this parameter - ctrl_ignore_cache - completely ignore the cache (but checkpointing can still work) - ctrl_force_shallow_recompute - refresh the cache (but not of sub-calls) - ctrl_force_deep_recompute - recursively refresh the cache - ctrl_nocopy - skip the usual copy of a cached return value - """ - def __call__(self, *args, **kwargs): - # N.B. - # ctrl_force_deep_recompute - # can work by inspecting the call stack - # if any parent frame has a special variable set (e.g. _child_frame_ctrl_force_deep_recompute) - # then it means this is a ctrl_force_deep_recompute too. - key, f_args, f_kwargs, ctrl_args = self.parse_args(args, kwargs) - - ctrl = ctrl_args['ctrl'] - if ctrl is None or ctrl_args['ctrl_ignore_cache']: - return self.f(*f_args, **f_kwargs) - if key: - try: - return self.get_cached_val(ctrl, key) - except KeyError: - pass - f_rval = self.f(*f_args, **f_kwargs) - if key: - f_rval = self.cache_val(ctrl, key, f_rval) - return f_rval - - def get_cached_val(self, ctrl, key): - return ctrl.get(key) - def cache_val(self, ctrl, key, val): - ctrl.set(key, val) - return val - -class NumpyCacheCtrl(CtrlObjCacheWrapper): - def get_cached_val(self, ctrl, key): - filename = ctrl.get(key)['npy_filename'] - return numpy.load(filename) - def cache_val(self, ctrl, key, val): - try: - filename = ctrl.get(key) - except KeyError: - handle, filename = ctrl.open_uniq() - handle.close() - ctrl.set(key, dict(npy_filename=filename)) - numpy.save(filename, val) - return val - -class PickleCacheCtrl(CtrlObjCacheWrapper): - def __init__(self, protocol=0, **kwargs): - self.protocol=protocol - super(PickleCacheCtrl, self).__init__(**kwargs) - def get_cached_val(self, ctrl, key): - return cPickle.loads(ctrl.get(key)['cPickle_str']) - def cache_val(self, ctrl, key, val): - ctrl.set(key, dict(cPickle_str=cPickle.dumps(val))) - return val - -@NumpyCacheCtrl.decorate() -def get_raw_data(rows, cols, seed=67273): - return numpy.random.RandomState(seed).randn(rows, cols) - -@NumpyCacheCtrl.decorate() -def get_whitened_dataset(X, pca, max_components=5): - return X[:,:max_components] - -@PickleCacheCtrl.decorate(protocol=-1) -def get_pca(X, max_components=100): - return dict( - mean=0, - eigvals=numpy.ones(X.shape[1]), - eigvecs=numpy.identity(X.shape[1]) - ) - -@PickleCacheCtrl.decorate(protocol=-1) -def train_mean_var_model(data, ctrl): - mean = numpy.zeros(data.shape[1]) - meansq = numpy.zeros(data.shape[1]) - for i in xrange(data.shape[0]): - alpha = 1.0 / (i+1) - mean += (1-alpha) * mean + data[i] * alpha - meansq += (1-alpha) * meansq + (data[i]**2) * alpha - ctrl.checkpoint() - return (mean, meansq) - -def test_run_experiment(): - - # Could use db, or filesystem, or both, etc. - # There would be generic ones, but the experimenter should be very aware of what is being - # cached where, when, and how. This is how results are stored and retrieved after all. - # Cluster-friendly jobs should not use local files directly, but should store cached - # computations and results to such a database. - # Different jobs should avoid using the same keys in the database because coordinating - # writes is difficult, and conflicts will inevitably arise. - ctrl = memory_ctrl_obj() - - raw_data = get_raw_data(ctrl=ctrl) - raw_data_key = ctrl.get_key(raw_data) - - pca = get_pca(raw_data, max_components=30, ctrl=ctrl) - whitened_data = get_whitened_dataset(raw_data, pca, ctrl=ctrl) - - mean_var = train_mean_var_model( - data=whitened_data+66, - ctrl=ctrl, - ctrl_key_data=whitened_data) #tell that the temporary is tied to whitened_data - - mean, var = mean_var - - #TODO: Test that the cache actually worked!! - -