# HG changeset patch # User James Bergstra # Date 1287444677 14400 # Node ID 09ad2a4f663cfc019aa5a9dbf9fa1d165d236280 # Parent 7c51c0355d863d0d282864e47a5dac4d24be3868 adding new idea to arch_src diff -r 7c51c0355d86 -r 09ad2a4f663c doc/v2_planning/arch_src/checkpoint_JB.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/doc/v2_planning/arch_src/checkpoint_JB.py Mon Oct 18 19:31:17 2010 -0400 @@ -0,0 +1,408 @@ +import copy as copy_module + +#TODO: use logging.info to report cache hits/ misses + +CO_VARARGS = 0x0004 +CO_VARKEYWORDS = 0x0008 + +class mem_db(dict): + # A key->document dictionary. + # A "document" is itself a dictionary. + + # A "document" can be a small or large object, but it cannot be partially retrieved. + + # This simple data structure is used in pylearn to cache intermediate reults between + # several process invocations. + pass + +class UNSPECIFIED(object): + pass + +class CtrlObj(object): + """ + Job control API. + + This interface makes it easier to break a logical program into pieces that can be + executed by several different processes, either serially or in parallel. + + + The base class provides decorators to simplify some common cache patterns: + - cache_pickle to cache arbitrary return values using the pickle mechanism + + - cache_dict to cache dict return values directly using the document db + + - cache_numpy to cache [single] numpy ndarray rvals in a way that supports memmapping of + large arrays. + + Authors are encouraged to use these when they apply, but should feel free to implement + other cache logic when these standard ones are lacking using the CtorlObj.get() and + CtorlObj.set() methods. + + + """ + + def __init__(self, rootdir, db, autosync): + self.db = db + self.r_lookup = {} + self.autosync=autosync + + def get(self, key, default_val=UNSPECIFIED, copy=True): + # Default to return a COPY because a self.set() is required to make a change persistent. + # Inplace changes that the CtrlObj does not know about (via self.set()) will not be saved. + try: + val = self.db[key] + except: + if default_val is not UNSPECIFIED: + # return default_val, but do not add it to the r_lookup object + # since looking up that key in the future would not retrieve default_val + return default_val + else: + raise + if copy: + rval = copy_module.deepcopy(val) + else: + rval = val + self.r_lookup[id(rval)] = key + return rval + + def get_key(self, val): + """Return the key that retrieved `val`. + + This is useful for specifying cache keys for unhashable (e.g. numpy) objects that + happen to be stored in the db. + """ + return self.r_lookup[id(val)] + def set(self, key, val): + vv = dict(val) + if self.db.get(key, None) not in (val, None): + del_keys = [k for (k,v) in self.r_lookup.iteritems() if v == key] + for k in del_keys: + del self.r_lookup[k] + self.db[key] = vv + def delete(self, key): + del_keys = [k for (k,v) in self.r_lookup.iteritems() if v == key] + for k in del_keys: + del self.r_lookup[k] + del self.db[key] + def checkpoint(self): + """Potentially pass control to another greenlet/tasklet that could potentially + serialize this (calling) greenlet/tasklet using cPickle. + """ + pass + + def sync(self, pull=True, push=True): + """Synchronise local changes with a master version (if applicable). + """ + pass + + def open(self, filename): + """Return a file-handle to a file that can be synced with a server""" + #todo - save references / proxies of the file objects returned here + # and sync them with a server if they are closed + return open(os.path.join(self.rootdir, filename)) + + def open_unique(self, mode='wb', prefix='uniq_', suffix=''): + #TODO: use the standard lib algo for this if you can find it. + if suffix: + template = prefix+'%06i.'+suffix + else: + template = prefix+'%06i' + while True: + fname = template%numpy.random.randint(999999) + path = os.path.join(self.rootdir, fname) + try: + open(path).close() + except IOError: #file not found + return open(path, mode=mode), path + +def memory_ctrl_obj(): + return CtrlObj(db=dict()) + +def directory_ctrl_obj(path, **kwargs): + raise NotImplementedError() + +def mongo_ctrl_obj(connection_args, **kwargs): + raise NotImplementedError() + +def couchdb_ctrl_obj(connection_args, **kwargs): + raise NotImplementedError() + +def jobman_ctrl_obj(connection_args, **kwargs): + raise NotImplementedError() + + +def _default_values(f): + """Return a dictionary param -> default value of function `f`'s parameters""" + default_dict = {} + func_defaults = f.func_defaults + if func_defaults: + first_default_pos = f.func_code.co_argcount-len(f.func_defaults) + params_with_defaults = f.func_code.co_varnames[first_default_pos:f.func_code.co_argcount] + rval = dict(zip(params_with_defaults, f.func_defaults)) + else: + rval = {} + return rval + +def test_default_values(): + + def f(a): pass + assert _default_values(f) == {} + + def f(a, b=1): + aa = 5 + assert _default_values(f) == dict(b=1) + + def f(a, b=1, c=2, *args, **kwargs): + e = b+c + return e + assert _default_values(f) == dict(b=1, c=2) + +def _arg_assignment(f, args, kwargs): + # make a dictionary from args and kwargs that contains all the arguments to f and their + # values + assignment = dict() + + params = f.func_code.co_varnames[:f.func_code.co_argcount] #CORRECT? + + f_accepts_varargs = f.func_code.co_flags & CO_VARARGS + f_accepts_kwargs = f.func_code.co_flags & CO_VARKEYWORDS + + if f_accepts_varargs: + raise NotImplementedError() + if f_accepts_kwargs: + raise NotImplementedError() + + # first add positional arguments + #TODO: what if f accepts a '*args' or similar? + assert len(args) <= f.func_code.co_argcount + for i, a in enumerate(args): + assignment[f.func_code.co_varnames[i]] = a # CORRECT?? + + # next add kw arguments + for k,v in kwargs.iteritems(): + if k in assignment: + #TODO: match Python error + raise TypeError('duplicate argument provided for parameter', k) + + if (not f_accepts_kwargs) and (k not in params): + #TODO: match Python error + raise TypeError('invalid keyword argument', k) + + assignment[k] = v + + # finally add default arguments for any remaining parameters + for k,v in _default_values(f).iteritems(): + if k in assignment: + pass # this argument has [already] been specified + else: + assignment[k] = v + + # TODO + # check that the assignment covers all parameters without default values + + # TODO + # check that the assignment includes no extra variables if f does not accept a '**' + # parameter. + + return assignment + +def test_arg_assignment(): + #TODO: check cases that should cause errors + # - doubly-specified arguments, + # - insufficient arguments + + def f():pass + assert _arg_assignment(f, (), {}) == {} + def f(a):pass + assert _arg_assignment(f, (1,), {}) == {'a':1} + def f(a):pass + assert _arg_assignment(f, (), {'a':1}) == {'a':1} + + def f(a=1):pass + assert _arg_assignment(f, (), {}) == {'a':1} + def f(a=1):pass + assert _arg_assignment(f, (2,), {}) == {'a':2} + def f(a=1):pass + assert _arg_assignment(f, (), {'a':2}) == {'a':2} + def f(a=1):pass + assert _arg_assignment(f, (), {'a':2}) == {'a':2} + + def f(b, a=1): pass + assert _arg_assignment(f, (3,4), {}) == {'b':3, 'a':4} + def f(b, a=1): pass + assert _arg_assignment(f, (3,), {'a':4}) == {'b':3, 'a':4} + def f(b, a=1): pass + assert _arg_assignment(f, (), {'b':3,'a':4}) == {'b':3, 'a':4} + def f(b, a=1): pass + assert _arg_assignment(f, (), {'b':3}) == {'b':3, 'a':1} + def f(b, a=1): a0=6 + assert _arg_assignment(f, (2,), {}) == {'b':2, 'a':1} + +if 0: + def test_arg_assignment_w_varargs(): + def f(b, c=1, *a, **kw): z=5 + assert _arg_assignment(f, (3,), {}) == {'b':3, 'c':1, 'a':(), 'kw':{}} + + +class CtrlObjCacheWrapper(object): + + @classmethod + def decorate(cls, *args, **kwargs): + self = cls(*args, **kwargs) + def rval(f): + self.f = f + return rval + def parse_args(self, args, kwargs): + """Return key, f_args, f_kwargs, by removing ctrl-cache related flags. + + The key is None or a hashable pair that identifies all the arguments to the function. + """ + ctrl_args = dict( + ctrl = None, + ctrl_ignore_cache=False, + ctrl_force_shallow_recompute=False, + ctrl_force_deep_recompute=False, + ) + + # remove the ctrl and ctrl_* arguments + # because they are not meant to be passed to 'f' + ctrl_kwds = [(k,v) for (k,v) in kwargs.iteritems() + if k.startswith('ctrl')] + ctrl_args.update(dict(ctrl_kwds)) + f_kwds = [(k,v) for (k,v) in kwargs.iteritems() + if not k.startswith('ctrl')] + + # assignment is a dictionary with a complete specification of the effective arguments to f + # including default values, varargs, and varkwargs. + assignment = _arg_assignment(self.f, args, dict(f_kwds)) + + assignment_items = assignment.items() + assignment_items.sort() #canonical ordering for parameters + + # replace argument values with explicitly provided keys + assignment_key = [(k, kwargs.get('ctrl_key_%s'%k, v)) + for (k,v) in assignment_items] + + rval_key = ('fn_cache', self.f, tuple(assignment_key)) + try: + hash(rval_key) + except: + rval_key = None + return rval_key, assignment, {}, ctrl_args + + def __doc__(self): + #TODO: Add documentation from self.f + return """ + Optional magic kwargs: + ctrl - use this handle for cache/checkpointing + ctrl_key_%(paramname)s - specify a key to use for a cache lookup of this parameter + ctrl_ignore_cache - completely ignore the cache (but checkpointing can still work) + ctrl_force_shallow_recompute - refresh the cache (but not of sub-calls) + ctrl_force_deep_recompute - recursively refresh the cache + ctrl_nocopy - skip the usual copy of a cached return value + """ + def __call__(self, *args, **kwargs): + # N.B. + # ctrl_force_deep_recompute + # can work by inspecting the call stack + # if any parent frame has a special variable set (e.g. _child_frame_ctrl_force_deep_recompute) + # then it means this is a ctrl_force_deep_recompute too. + key, f_args, f_kwargs, ctrl_args = self.parse_args(args, kwargs) + + ctrl = ctrl_args['ctrl'] + if ctrl is None or ctrl_args['ctrl_ignore_cache']: + return self.f(*f_args, **f_kwargs) + if key: + try: + return self.get_cached_val(ctrl, key) + except KeyError: + pass + f_rval = self.f(*f_args, **f_kwargs) + if key: + f_rval = self.cache_val(ctrl, key, f_rval) + return f_rval + + def get_cached_val(self, ctrl, key): + return ctrl.get(key) + def cache_val(self, ctrl, key, val): + ctrl.set(key, val) + return val + +class NumpyCacheCtrl(CtrlObjCacheWrapper): + def get_cached_val(self, ctrl, key): + filename = ctrl.get(key)['npy_filename'] + return numpy.load(filename) + def cache_val(self, ctrl, key, val): + try: + filename = ctrl.get(key) + except KeyError: + handle, filename = ctrl.open_uniq() + handle.close() + ctrl.set(key, dict(npy_filename=filename)) + numpy.save(filename, val) + return val + +class PickleCacheCtrl(CtrlObjCacheWrapper): + def __init__(self, protocol=0, **kwargs): + self.protocol=protocol + super(PickleCacheCtrl, self).__init__(**kwargs) + def get_cached_val(self, ctrl, key): + return cPickle.loads(ctrl.get(key)['cPickle_str']) + def cache_val(self, ctrl, key, val): + ctrl.set(key, dict(cPickle_str=cPickle.dumps(val))) + return val + +@NumpyCacheCtrl.decorate() +def get_raw_data(rows, cols, seed=67273): + return numpy.random.RandomState(seed).randn(rows, cols) + +@NumpyCacheCtrl.decorate() +def get_whitened_dataset(X, pca, max_components=5): + return X[:,:max_components] + +@PickleCacheCtrl.decorate(protocol=-1) +def get_pca(X, max_components=100): + return dict( + mean=0, + eigvals=numpy.ones(X.shape[1]), + eigvecs=numpy.identity(X.shape[1]) + ) + +@PickleCacheCtrl.decorate(protocol=-1) +def train_mean_var_model(data, ctrl): + mean = numpy.zeros(data.shape[1]) + meansq = numpy.zeros(data.shape[1]) + for i in xrange(data.shape[0]): + alpha = 1.0 / (i+1) + mean += (1-alpha) * mean + data[i] * alpha + meansq += (1-alpha) * meansq + (data[i]**2) * alpha + ctrl.checkpoint() + return (mean, meansq) + +def test_run_experiment(): + + # Could use db, or filesystem, or both, etc. + # There would be generic ones, but the experimenter should be very aware of what is being + # cached where, when, and how. This is how results are stored and retrieved after all. + # Cluster-friendly jobs should not use local files directly, but should store cached + # computations and results to such a database. + # Different jobs should avoid using the same keys in the database because coordinating + # writes is difficult, and conflicts will inevitably arise. + ctrl = memory_ctrl_obj() + + raw_data = get_raw_data(ctrl=ctrl) + raw_data_key = ctrl.get_key(raw_data) + + pca = get_pca(raw_data, max_components=30, ctrl=ctrl) + whitened_data = get_whitened_dataset(raw_data, pca, ctrl=ctrl) + + mean_var = train_mean_var_model( + data=whitened_data+66, + ctrl=ctrl, + ctrl_key_data=whitened_data) #tell that the temporary is tied to whitened_data + + mean, var = mean_var + + #TODO: Test that the cache actually worked!! + +