changeset 1336:09ad2a4f663c

adding new idea to arch_src
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 18 Oct 2010 19:31:17 -0400
parents 7c51c0355d86
children 7dfc3d3052ea
files doc/v2_planning/arch_src/checkpoint_JB.py
diffstat 1 files changed, 408 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/doc/v2_planning/arch_src/checkpoint_JB.py	Mon Oct 18 19:31:17 2010 -0400
@@ -0,0 +1,408 @@
import copy as copy_module
import os

try:
    import cPickle  # Python 2
except ImportError:
    import pickle as cPickle  # Python 3 fallback

import numpy
+
+#TODO: use logging.info to report cache hits/ misses
+
# Bit flags from the `co_flags` field of CPython code objects (the same CO_*
# values exposed by the `inspect` module): set when a function accepts
# *args / **kwargs respectively.
CO_VARARGS = 0x0004
CO_VARKEYWORDS = 0x0008
+
class mem_db(dict):
    """A key->document dictionary.

    A "document" is itself a dictionary.  A document can be a small or large
    object, but it cannot be partially retrieved.

    This simple data structure is used in pylearn to cache intermediate
    results between several process invocations.
    """
+
class UNSPECIFIED(object):
    """Sentinel type used to distinguish "no default supplied" from any real
    value (including None) in CtrlObj.get()."""
+
class CtrlObj(object):
    """
    Job control API.

    This interface makes it easier to break a logical program into pieces that can be
    executed by several different processes, either serially or in parallel.

    The base class provides decorators to simplify some common cache patterns:
     - cache_pickle to cache arbitrary return values using the pickle mechanism

     - cache_dict to cache dict return values directly using the document db

     - cache_numpy to cache [single] numpy ndarray rvals in a way that supports memmapping of
       large arrays.

    Authors are encouraged to use these when they apply, but should feel free to implement
    other cache logic when these standard ones are lacking using the CtrlObj.get() and
    CtrlObj.set() methods.
    """

    def __init__(self, rootdir, db, autosync):
        """
        :param rootdir: directory in which `open` / `open_unique` create files
        :param db: key->document mapping used as the persistent store
        :param autosync: whether changes should be synced eagerly (see `sync`)
        """
        # BUGFIX: rootdir was accepted but never stored, so `open` and
        # `open_unique` raised AttributeError on self.rootdir.
        self.rootdir = rootdir
        self.db = db
        # Maps id(returned value) -> db key, so get_key() can recover the key
        # that produced a (possibly unhashable) value.
        self.r_lookup = {}
        self.autosync = autosync

    def get(self, key, default_val=UNSPECIFIED, copy=True):
        """Return the document stored under `key`.

        Default to returning a COPY because a self.set() is required to make a
        change persistent.  In-place changes that the CtrlObj does not know
        about (via self.set()) will not be saved.

        :param default_val: returned when `key` is missing; if UNSPECIFIED,
            the KeyError propagates instead.
        :param copy: if False, return the stored object itself (the caller
            must then not mutate it).
        """
        try:
            val = self.db[key]
        except KeyError:  # BUGFIX: was a bare `except`, which hid real errors
            if default_val is not UNSPECIFIED:
                # return default_val, but do not add it to the r_lookup object
                # since looking up that key in the future would not retrieve default_val
                return default_val
            raise
        rval = copy_module.deepcopy(val) if copy else val
        self.r_lookup[id(rval)] = key
        return rval

    def get_key(self, val):
        """Return the key that retrieved `val`.

        This is useful for specifying cache keys for unhashable (e.g. numpy)
        objects that happen to be stored in the db.
        """
        return self.r_lookup[id(val)]

    def set(self, key, val):
        """Store (a shallow copy of) dict `val` under `key`.

        If this rebinds `key` to a different document, stale reverse-lookup
        entries pointing at `key` are invalidated.
        """
        vv = dict(val)
        if self.db.get(key, None) not in (val, None):
            # `.items()` works on both Python 2 and 3 (iteritems was removed)
            del_keys = [k for (k, v) in self.r_lookup.items() if v == key]
            for k in del_keys:
                del self.r_lookup[k]
        self.db[key] = vv

    def delete(self, key):
        """Remove `key` from the db and drop its reverse-lookup entries."""
        del_keys = [k for (k, v) in self.r_lookup.items() if v == key]
        for k in del_keys:
            del self.r_lookup[k]
        del self.db[key]

    def checkpoint(self):
        """Potentially pass control to another greenlet/tasklet that could potentially
        serialize this (calling) greenlet/tasklet using cPickle.
        """
        pass

    def sync(self, pull=True, push=True):
        """Synchronise local changes with a master version (if applicable)."""
        pass

    def open(self, filename):
        """Return a file-handle to a file that can be synced with a server."""
        # todo - save references / proxies of the file objects returned here
        # and sync them with a server if they are closed
        return open(os.path.join(self.rootdir, filename))

    def open_unique(self, mode='wb', prefix='uniq_', suffix=''):
        """Create a uniquely-named file under self.rootdir.

        :returns: (open file handle, path)
        """
        # TODO: use the standard lib algo for this (tempfile.mkstemp).
        if suffix:
            template = prefix + '%06i.' + suffix
        else:
            template = prefix + '%06i'
        while True:
            fname = template % numpy.random.randint(999999)
            path = os.path.join(self.rootdir, fname)
            try:
                # existence probe: succeeds only if the name is already taken
                open(path).close()
            except IOError:  # file not found -> this name is free
                return open(path, mode=mode), path
+
def memory_ctrl_obj():
    """Return a CtrlObj backed by a throwaway in-process dictionary.

    BUGFIX: CtrlObj.__init__ requires rootdir and autosync as well, so the
    original call CtrlObj(db=dict()) raised TypeError.  rootdir=None means the
    file-based helpers (`open`, `open_unique`) are unavailable for this ctrl.
    """
    return CtrlObj(rootdir=None, db=dict(), autosync=False)
+
def directory_ctrl_obj(path, **kwargs):
    """Placeholder for a CtrlObj backed by files under `path` (not written yet)."""
    raise NotImplementedError()
+
def mongo_ctrl_obj(connection_args, **kwargs):
    """Placeholder for a CtrlObj backed by a MongoDB collection (not written yet)."""
    raise NotImplementedError()
+
def couchdb_ctrl_obj(connection_args, **kwargs):
    """Placeholder for a CtrlObj backed by a CouchDB database (not written yet)."""
    raise NotImplementedError()
+
def jobman_ctrl_obj(connection_args, **kwargs):
    """Placeholder for a CtrlObj backed by a jobman database (not written yet)."""
    raise NotImplementedError()
+
+
+def _default_values(f):
+    """Return a dictionary param -> default value of function `f`'s parameters"""
+    default_dict = {}
+    func_defaults = f.func_defaults
+    if func_defaults:
+        first_default_pos = f.func_code.co_argcount-len(f.func_defaults)
+        params_with_defaults = f.func_code.co_varnames[first_default_pos:f.func_code.co_argcount]
+        rval = dict(zip(params_with_defaults, f.func_defaults))
+    else:
+        rval = {}
+    return rval
+
def test_default_values():
    """Sanity-check _default_values against a few representative signatures."""

    def no_defaults(a):
        pass
    assert _default_values(no_defaults) == {}

    def one_default(a, b=1):
        aa = 5
    assert _default_values(one_default) == dict(b=1)

    def star_args(a, b=1, c=2, *args, **kwargs):
        e = b + c
        return e
    assert _default_values(star_args) == dict(b=1, c=2)
+
def _arg_assignment(f, args, kwargs):
    """Return a dict mapping each of `f`'s parameters to its effective value
    for the call f(*args, **kwargs), filling in declared defaults.

    :raises TypeError: on too many positional args, duplicate, or invalid
        keyword arguments.
    :raises NotImplementedError: if `f` accepts *args or **kwargs.
    """
    assignment = dict()

    code = f.__code__  # portable spelling of f.func_code (removed in Python 3)
    # The first co_argcount entries of co_varnames are the named parameters.
    params = code.co_varnames[:code.co_argcount]

    f_accepts_varargs = code.co_flags & CO_VARARGS
    f_accepts_kwargs = code.co_flags & CO_VARKEYWORDS

    if f_accepts_varargs:
        raise NotImplementedError()
    if f_accepts_kwargs:
        raise NotImplementedError()

    # first add positional arguments
    # BUGFIX: raise TypeError (like Python does) instead of assert, which
    # disappears under `python -O`.
    if len(args) > code.co_argcount:
        raise TypeError('too many positional arguments', len(args))
    for i, a in enumerate(args):
        assignment[params[i]] = a

    # next add kw arguments
    for k, v in kwargs.items():
        if k in assignment:
            # TODO: match Python error
            raise TypeError('duplicate argument provided for parameter', k)

        if (not f_accepts_kwargs) and (k not in params):
            # TODO: match Python error
            raise TypeError('invalid keyword argument', k)

        assignment[k] = v

    # finally add default arguments for any remaining parameters
    for k, v in _default_values(f).items():
        assignment.setdefault(k, v)

    # TODO
    # check that the assignment covers all parameters without default values

    # TODO
    # check that the assignment includes no extra variables if f does not
    # accept a '**' parameter.

    return assignment
+
def test_arg_assignment():
    """Exercise _arg_assignment over positional/keyword/default combinations."""
    # TODO: check cases that should cause errors
    # - doubly-specified arguments,
    # - insufficient arguments

    def nullary(): pass
    assert _arg_assignment(nullary, (), {}) == {}

    def unary(a): pass
    assert _arg_assignment(unary, (1,), {}) == {'a': 1}
    assert _arg_assignment(unary, (), {'a': 1}) == {'a': 1}

    def defaulted(a=1): pass
    assert _arg_assignment(defaulted, (), {}) == {'a': 1}
    assert _arg_assignment(defaulted, (2,), {}) == {'a': 2}
    assert _arg_assignment(defaulted, (), {'a': 2}) == {'a': 2}
    assert _arg_assignment(defaulted, (), {'a': 2}) == {'a': 2}

    def mixed(b, a=1): pass
    assert _arg_assignment(mixed, (3, 4), {}) == {'b': 3, 'a': 4}
    assert _arg_assignment(mixed, (3,), {'a': 4}) == {'b': 3, 'a': 4}
    assert _arg_assignment(mixed, (), {'b': 3, 'a': 4}) == {'b': 3, 'a': 4}
    assert _arg_assignment(mixed, (), {'b': 3}) == {'b': 3, 'a': 1}

    def mixed_with_body(b, a=1): a0 = 6
    assert _arg_assignment(mixed_with_body, (2,), {}) == {'b': 2, 'a': 1}
+
if 0:
    # Disabled until _arg_assignment supports *args/**kwargs -- it currently
    # raises NotImplementedError for such signatures.
    def test_arg_assignment_w_varargs():
        def varargs_fn(b, c=1, *a, **kw):
            z = 5
        assert _arg_assignment(varargs_fn, (3,), {}) == {'b': 3, 'c': 1, 'a': (), 'kw': {}}
+
+
class CtrlObjCacheWrapper(object):
    """Instance-based decorator that memoizes a function's return value in a
    CtrlObj-style cache.  Subclasses override get_cached_val / cache_val to
    control how values are stored (pickle string, numpy file, ...)."""

    @classmethod
    def decorate(cls, *args, **kwargs):
        """Return a decorator that wraps a function in a `cls` cache instance."""
        self = cls(*args, **kwargs)
        def rval(f):
            self.f = f
            # BUGFIX: a decorator must return the replacement callable;
            # previously it returned None, so every decorated function
            # (get_raw_data, get_pca, ...) was silently rebound to None.
            return self
        return rval

    def parse_args(self, args, kwargs):
        """Return key, f_args, f_kwargs, ctrl_args by removing ctrl-cache related flags.

        The key is None or a hashable tuple that identifies all the arguments
        to the function.
        """
        ctrl_args = dict(
                ctrl=None,
                ctrl_ignore_cache=False,
                ctrl_force_shallow_recompute=False,
                ctrl_force_deep_recompute=False,
                )

        # remove the ctrl and ctrl_* arguments
        # because they are not meant to be passed to 'f'
        # (`.items()` works on Python 2 and 3; iteritems was removed)
        ctrl_kwds = [(k, v) for (k, v) in kwargs.items()
                if k.startswith('ctrl')]
        ctrl_args.update(dict(ctrl_kwds))
        f_kwds = [(k, v) for (k, v) in kwargs.items()
                if not k.startswith('ctrl')]

        # assignment is a dictionary with a complete specification of the
        # effective arguments to f, including default values.
        assignment = _arg_assignment(self.f, args, dict(f_kwds))

        # canonical ordering for parameters
        assignment_items = sorted(assignment.items())

        # replace argument values with explicitly provided keys
        assignment_key = [(k, kwargs.get('ctrl_key_%s' % k, v))
                for (k, v) in assignment_items]

        rval_key = ('fn_cache', self.f, tuple(assignment_key))
        try:
            hash(rval_key)
        except TypeError:  # some argument (or its key) is unhashable -> uncacheable
            rval_key = None
        # BUGFIX: all effective arguments live in `assignment` (keyword form),
        # so the positional slot is empty.  Previously the dict was returned
        # as f_args and unpacked with '*', which passed the parameter NAMES
        # to f instead of the values.
        return rval_key, (), assignment, ctrl_args

    def __doc__(self):
        # NOTE(review): defining __doc__ as a method means help() shows a
        # bound method rather than this text; a plain class docstring would
        # be the conventional fix, kept as-is to preserve the interface.
        # TODO: Add documentation from self.f
        return """
        Optional magic kwargs:
        ctrl                   - use this handle for cache/checkpointing
        ctrl_key_%(paramname)s - specify a key to use for a cache lookup of this parameter
        ctrl_ignore_cache      - completely ignore the cache (but checkpointing can still work)
        ctrl_force_shallow_recompute - refresh the cache (but not of sub-calls)
        ctrl_force_deep_recompute - recursively refresh the cache
        ctrl_nocopy            - skip the usual copy of a cached return value
        """

    def __call__(self, *args, **kwargs):
        """Call self.f, consulting and updating the cache when a ctrl is given."""
        # N.B.
        #  ctrl_force_deep_recompute
        #  can work by inspecting the call stack:
        #  if any parent frame has a special variable set (e.g.
        #  _child_frame_ctrl_force_deep_recompute) then it means this is a
        #  ctrl_force_deep_recompute too.
        key, f_args, f_kwargs, ctrl_args = self.parse_args(args, kwargs)

        ctrl = ctrl_args['ctrl']
        # BUGFIX: functions that declare a 'ctrl' parameter (e.g. to call
        # ctrl.checkpoint()) never received it, because parse_args strips
        # every ctrl* kwarg.  Re-insert it here; the cache key is computed
        # without it, which is the desired behavior.
        code = self.f.__code__
        if 'ctrl' in code.co_varnames[:code.co_argcount]:
            f_kwargs['ctrl'] = ctrl
        if ctrl is None or ctrl_args['ctrl_ignore_cache']:
            return self.f(*f_args, **f_kwargs)
        if key:
            try:
                return self.get_cached_val(ctrl, key)
            except KeyError:
                pass  # cache miss: fall through and compute
        f_rval = self.f(*f_args, **f_kwargs)
        if key:
            f_rval = self.cache_val(ctrl, key, f_rval)
        return f_rval

    def get_cached_val(self, ctrl, key):
        """Return the cached value for `key`; raise KeyError on a miss."""
        return ctrl.get(key)

    def cache_val(self, ctrl, key, val):
        """Store `val` under `key`; return the value to hand back to the caller."""
        ctrl.set(key, val)
        return val
+
class NumpyCacheCtrl(CtrlObjCacheWrapper):
    """Cache numpy ndarray return values in .npy files.

    The db document records only the filename, which keeps the document small
    and supports memmapping of large arrays.
    """

    def get_cached_val(self, ctrl, key):
        # A missing key raises KeyError via ctrl.get -> treated as cache miss.
        filename = ctrl.get(key)['npy_filename']
        return numpy.load(filename)

    def cache_val(self, ctrl, key, val):
        try:
            # BUGFIX: the stored document is a dict; extract the filename
            # (previously the whole dict was bound to `filename`).
            filename = ctrl.get(key)['npy_filename']
        except KeyError:
            # BUGFIX: the CtrlObj method is spelled open_unique (not
            # open_uniq).  Ask for a '.npy' suffix because numpy.save appends
            # '.npy' to any filename lacking it, which would break the later
            # numpy.load(filename).
            handle, filename = ctrl.open_unique(suffix='npy')
            handle.close()
            ctrl.set(key, dict(npy_filename=filename))
        numpy.save(filename, val)
        return val
+
class PickleCacheCtrl(CtrlObjCacheWrapper):
    """Cache arbitrary picklable return values as pickle strings in the db."""

    def __init__(self, protocol=0, **kwargs):
        # protocol: pickle protocol number (-1 means highest available)
        self.protocol = protocol
        super(PickleCacheCtrl, self).__init__(**kwargs)

    def get_cached_val(self, ctrl, key):
        # A missing key raises KeyError via ctrl.get -> treated as cache miss.
        return cPickle.loads(ctrl.get(key)['cPickle_str'])

    def cache_val(self, ctrl, key, val):
        # BUGFIX: honour the requested protocol; it was accepted by __init__
        # but never passed to dumps (which then used protocol 0).
        ctrl.set(key, dict(cPickle_str=cPickle.dumps(val, self.protocol)))
        return val
+
@NumpyCacheCtrl.decorate()
def get_raw_data(rows, cols, seed=67273):
    """Return a (rows, cols) array of standard-normal samples, deterministic in `seed`."""
    rng = numpy.random.RandomState(seed)
    return rng.randn(rows, cols)
+
@NumpyCacheCtrl.decorate()
def get_whitened_dataset(X, pca, max_components=5):
    """Return the first `max_components` columns of X.

    NOTE(review): the `pca` argument is currently unused -- presumably a
    placeholder for a real whitening transform.
    """
    truncated = X[:, :max_components]
    return truncated
+
@PickleCacheCtrl.decorate(protocol=-1)
def get_pca(X, max_components=100):
    """Return a stub PCA model (zero mean, identity transform) for matrix X.

    NOTE(review): `max_components` is currently unused by this stub.
    """
    n_features = X.shape[1]
    return dict(
            mean=0,
            eigvals=numpy.ones(n_features),
            eigvecs=numpy.identity(n_features),
            )
+
@PickleCacheCtrl.decorate(protocol=-1)
def train_mean_var_model(data, ctrl):
    """Return (mean, mean-of-squares) per column of `data`.

    Computed row-by-row as an exponentially re-weighted running average that
    is exact (alpha = 1/(i+1)), checkpointing via `ctrl` after each row so the
    loop can be suspended/resumed.
    """
    mean = numpy.zeros(data.shape[1])
    meansq = numpy.zeros(data.shape[1])
    # range (not xrange) works identically on Python 2 and 3 here
    for i in range(data.shape[0]):
        alpha = 1.0 / (i + 1)
        # BUGFIX: the running average is an assignment, not '+='.  The old
        # code computed mean + ((1-alpha)*mean + x*alpha), which grows without
        # bound instead of converging to the column mean.
        mean = (1 - alpha) * mean + data[i] * alpha
        meansq = (1 - alpha) * meansq + (data[i] ** 2) * alpha
        ctrl.checkpoint()
    return (mean, meansq)
+
def test_run_experiment():
    """End-to-end sketch of a cached experiment pipeline.

    Not yet a real unit test: it only exercises the happy path and does not
    verify cache hits (see the TODO at the bottom).
    """

    # Could use db, or filesystem, or both, etc.
    # There would be generic ones, but the experimenter should be very aware of what is being
    # cached where, when, and how.  This is how results are stored and retrieved after all.
    # Cluster-friendly jobs should not use local files directly, but should store cached
    # computations and results to such a database.
    #  Different jobs should avoid using the same keys in the database because coordinating
    #  writes is difficult, and conflicts will inevitably arise.
    ctrl = memory_ctrl_obj()

    # BUGFIX: rows and cols have no defaults in get_raw_data, so they must be
    # supplied explicitly; the original call could not succeed.
    raw_data = get_raw_data(rows=100, cols=10, ctrl=ctrl)
    raw_data_key = ctrl.get_key(raw_data)

    pca = get_pca(raw_data, max_components=30, ctrl=ctrl)
    whitened_data = get_whitened_dataset(raw_data, pca, ctrl=ctrl)

    mean_var = train_mean_var_model(
            data=whitened_data + 66,
            ctrl=ctrl,
            ctrl_key_data=whitened_data)  # tell that the temporary is tied to whitened_data

    mean, var = mean_var

    #TODO: Test that the cache actually worked!!
+
+