# doc/v2_planning/arch_src/checkpoint_JB.py @ 1336:09ad2a4f663c
# changeset: "adding new idea to arch_src"
# author:    James Bergstra <bergstrj@iro.umontreal.ca>
# date:      Mon, 18 Oct 2010 19:31:17 -0400

import copy as copy_module
import cPickle
import os
import tempfile

import numpy

#TODO: use logging.info to report cache hits/misses

# code-object flag bits for f.func_code.co_flags (see the dis module docs)
CO_VARARGS = 0x0004
CO_VARKEYWORDS = 0x0008

class mem_db(dict):
    """A key->document dictionary, where a "document" is itself a dictionary.

    A document can be a small or a large object, but it cannot be partially
    retrieved.

    This simple data structure is used in pylearn to cache intermediate
    results between several process invocations.
    """
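
# A minimal usage sketch of the mem_db document store (the key and document
# below are hypothetical):
#
#     db = mem_db()
#     db['pca_run_1'] = {'mean': 0.0, 'n_components': 30}
#     doc = db['pca_run_1']       # documents are retrieved whole
#     doc['n_components']         # -> 30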

class UNSPECIFIED(object):
    """Sentinel meaning "no default value was provided" in CtrlObj.get()."""

class CtrlObj(object):
    """
    Job control API.

    This interface makes it easier to break a logical program into pieces
    that can be executed by several different processes, either serially or
    in parallel.

    Decorator classes (see CtrlObjCacheWrapper and its subclasses below)
    simplify some common cache patterns:

     - cache_pickle to cache arbitrary return values using the pickle
       mechanism

     - cache_dict to cache dict return values directly using the document db

     - cache_numpy to cache [single] numpy ndarray rvals in a way that
       supports memmapping of large arrays.

    Authors are encouraged to use these when they apply, but should feel
    free to implement other cache logic, using the CtrlObj.get() and
    CtrlObj.set() methods, when the standard patterns are lacking.
    """

    def __init__(self, rootdir, db, autosync):
        self.rootdir = rootdir
        self.db = db
        self.r_lookup = {}
        self.autosync = autosync

    def get(self, key, default_val=UNSPECIFIED, copy=True):
        # Default to returning a COPY, because a self.set() is required to
        # make a change persistent.  In-place changes that the CtrlObj does
        # not know about (via self.set()) will not be saved.
        try:
            val = self.db[key]
        except KeyError:
            if default_val is not UNSPECIFIED:
                # return default_val, but do not add it to the r_lookup
                # object, since looking up that key in the future would not
                # retrieve default_val
                return default_val
            else:
                raise
        if copy:
            rval = copy_module.deepcopy(val)
        else:
            rval = val
        self.r_lookup[id(rval)] = key
        return rval

    def get_key(self, val):
        """Return the key that retrieved `val`.
        
        This is useful for specifying cache keys for unhashable (e.g. numpy) objects that
        happen to be stored in the db.
        """
        return self.r_lookup[id(val)]

    def set(self, key, val):
        vv = dict(val)
        if self.db.get(key, None) not in (val, None):
            # the document stored under `key` is changing, so purge any
            # stale reverse-lookup entries that still point at it
            del_keys = [k for (k, v) in self.r_lookup.iteritems() if v == key]
            for k in del_keys:
                del self.r_lookup[k]
        self.db[key] = vv

    def delete(self, key):
        del_keys = [k for (k, v) in self.r_lookup.iteritems() if v == key]
        for k in del_keys:
            del self.r_lookup[k]
        del self.db[key]

    def checkpoint(self):
        """Potentially pass control to another greenlet/tasklet, which could
        serialize this (calling) greenlet/tasklet using cPickle.
        """
        pass

    def sync(self, pull=True, push=True):
        """Synchronise local changes with a master version (if applicable).
        """
        pass

    def open(self, filename):
        """Return a file-handle to a file that can be synced with a server"""
        #todo - save references / proxies of the file objects returned here
        # and sync them with a server if they are closed
        return open(os.path.join(self.rootdir, filename))

    def open_unique(self, mode='wb', prefix='uniq_', suffix=''):
        """Return (handle, path) for a newly created, uniquely named file in
        self.rootdir."""
        #TODO: tempfile.mkstemp in the standard library does this more robustly
        if suffix:
            template = prefix + '%06i.' + suffix
        else:
            template = prefix + '%06i'
        while True:
            fname = template % numpy.random.randint(999999)
            path = os.path.join(self.rootdir, fname)
            try:
                open(path).close()
            except IOError:  # file not found, so this name is unused
                return open(path, mode=mode), path
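
# A minimal sketch of the CtrlObj get/set/get_key round-trip, assuming an
# in-memory db (the key and document below are hypothetical):
#
#     ctrl = CtrlObj(rootdir=None, db=dict(), autosync=False)
#     ctrl.set('stats', dict(mean=0.5, var=2.0))
#     doc = ctrl.get('stats')          # a deep COPY of the stored document
#     ctrl.get_key(doc)                # -> 'stats'
#     doc['mean'] = 0.6                # not persistent until...
#     ctrl.set(ctrl.get_key(doc), doc) # ...it is written back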

def memory_ctrl_obj(rootdir=None):
    # N.B. file-based caches (e.g. NumpyCacheCtrl) need a real rootdir
    return CtrlObj(rootdir=rootdir, db=dict(), autosync=False)

def directory_ctrl_obj(path, **kwargs):
    raise NotImplementedError()

def mongo_ctrl_obj(connection_args, **kwargs):
    raise NotImplementedError()

def couchdb_ctrl_obj(connection_args, **kwargs):
    raise NotImplementedError()

def jobman_ctrl_obj(connection_args, **kwargs):
    raise NotImplementedError()


def _default_values(f):
    """Return a dictionary param -> default value of function `f`'s parameters"""
    func_defaults = f.func_defaults
    if func_defaults:
        first_default_pos = f.func_code.co_argcount - len(func_defaults)
        params_with_defaults = f.func_code.co_varnames[
                first_default_pos:f.func_code.co_argcount]
        rval = dict(zip(params_with_defaults, func_defaults))
    else:
        rval = {}
    return rval

def test_default_values():

    def f(a): pass
    assert _default_values(f) == {}

    def f(a, b=1): 
        aa = 5
    assert _default_values(f) == dict(b=1)

    def f(a, b=1, c=2, *args, **kwargs):
        e = b+c
        return e
    assert _default_values(f) == dict(b=1, c=2)

def _arg_assignment(f, args, kwargs):
    # make a dictionary from args and kwargs that contains all the arguments to f and their
    # values
    assignment = dict()

    params = f.func_code.co_varnames[:f.func_code.co_argcount]  # f's named parameters, in declaration order

    f_accepts_varargs = f.func_code.co_flags & CO_VARARGS
    f_accepts_kwargs = f.func_code.co_flags & CO_VARKEYWORDS

    if f_accepts_varargs:
        raise NotImplementedError()
    if f_accepts_kwargs:
        raise NotImplementedError()

    # first add positional arguments
    #TODO: what if f accepts a '*args' or similar? (rejected above for now)
    assert len(args) <= f.func_code.co_argcount
    for i, a in enumerate(args):
        assignment[params[i]] = a  # params[i] is the i-th positional parameter of f

    # next add kw arguments 
    for k,v in kwargs.iteritems():
        if k in assignment:
            #TODO: match Python error
            raise TypeError('duplicate argument provided for parameter', k)

        if (not f_accepts_kwargs) and (k not in params):
            #TODO: match Python error
            raise TypeError('invalid keyword argument', k)

        assignment[k] = v

    # finally add default arguments for any remaining parameters
    for k,v in _default_values(f).iteritems():
        if k in assignment:
            pass # this argument has already been specified
        else:
            assignment[k] = v

    # TODO
    # check that the assignment covers all parameters without default values

    # TODO
    # check that the assignment includes no extra variables if f does not accept a '**'
    # parameter.

    return assignment

def test_arg_assignment():
    #TODO: check cases that should cause errors 
    # - doubly-specified arguments, 
    # - insufficient arguments

    def f():pass
    assert _arg_assignment(f, (), {}) == {}
    def f(a):pass
    assert _arg_assignment(f, (1,), {}) == {'a':1}
    def f(a):pass
    assert _arg_assignment(f, (), {'a':1}) == {'a':1}

    def f(a=1):pass
    assert _arg_assignment(f, (), {}) == {'a':1}
    def f(a=1):pass
    assert _arg_assignment(f, (2,), {}) == {'a':2}
    def f(a=1):pass
    assert _arg_assignment(f, (), {'a':2}) == {'a':2}

    def f(b, a=1): pass
    assert _arg_assignment(f, (3,4), {}) == {'b':3, 'a':4}
    def f(b, a=1): pass
    assert _arg_assignment(f, (3,), {'a':4}) == {'b':3, 'a':4}
    def f(b, a=1): pass
    assert _arg_assignment(f, (), {'b':3,'a':4}) == {'b':3, 'a':4}
    def f(b, a=1): pass
    assert _arg_assignment(f, (), {'b':3}) == {'b':3, 'a':1}
    def f(b, a=1): a0=6
    assert _arg_assignment(f, (2,), {}) == {'b':2, 'a':1}

if 0:
    def test_arg_assignment_w_varargs():
        def f(b, c=1, *a, **kw): z=5
        assert _arg_assignment(f, (3,), {}) == {'b':3, 'c':1, 'a':(), 'kw':{}}


class CtrlObjCacheWrapper(object):

    @classmethod
    def decorate(cls, *args, **kwargs):
        self = cls(*args, **kwargs)
        def rval(f):
            self.f = f
            return self  # the wrapper instance is the decorated callable
        return rval
    def parse_args(self, args, kwargs):
        """Return key, f_args, f_kwargs, ctrl_args, by removing ctrl-cache
        related flags.

        The key is None or a hashable tuple that identifies all the
        arguments to the function.
        """
        ctrl_args = dict(
                ctrl=None,
                ctrl_ignore_cache=False,
                ctrl_force_shallow_recompute=False,
                ctrl_force_deep_recompute=False,
                )

        # remove the ctrl and ctrl_* arguments,
        # because they are not meant to be passed to 'f'
        ctrl_kwds = [(k, v) for (k, v) in kwargs.iteritems()
                if k.startswith('ctrl')]
        ctrl_args.update(dict(ctrl_kwds))
        f_kwds = [(k, v) for (k, v) in kwargs.iteritems()
                if not k.startswith('ctrl')]

        # assignment is a dictionary with a complete specification of the
        # effective arguments to f, including default values, varargs, and
        # varkwargs
        assignment = _arg_assignment(self.f, args, dict(f_kwds))

        assignment_items = assignment.items()
        assignment_items.sort()  # canonical ordering for parameters

        # replace argument values with explicitly provided keys
        assignment_key = [(k, kwargs.get('ctrl_key_%s' % k, v))
                for (k, v) in assignment_items]

        rval_key = ('fn_cache', self.f, tuple(assignment_key))
        try:
            hash(rval_key)
        except TypeError:  # some argument value is unhashable
            rval_key = None
        # assignment is a complete name->value mapping, so pass all of it
        # through as keyword arguments
        return rval_key, (), assignment, ctrl_args
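
    # For example (hypothetical function and values): if f is
    #     def f(a, b=2): ...
    # and the wrapper is called as w(1, ctrl=some_ctrl, ctrl_key_a='a_key'),
    # then parse_args returns roughly
    #     key      == ('fn_cache', f, (('a', 'a_key'), ('b', 2)))
    #     f_kwargs == {'a': 1, 'b': 2}
    # and ctrl_args['ctrl'] is some_ctrl.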

    @property
    def __doc__(self):
        #TODO: add documentation from self.f
        return """
        Optional magic kwargs:
        ctrl                   - use this handle for cache/checkpointing
        ctrl_key_%(paramname)s - specify a key to use for a cache lookup of this parameter
        ctrl_ignore_cache      - completely ignore the cache (but checkpointing can still work)
        ctrl_force_shallow_recompute - refresh the cache (but not the caches of sub-calls)
        ctrl_force_deep_recompute - recursively refresh the cache
        ctrl_nocopy            - skip the usual copy of a cached return value
        """
    def __call__(self, *args, **kwargs):
        # N.B. ctrl_force_deep_recompute can work by inspecting the call
        # stack: if any parent frame has a special variable set (e.g.
        # _child_frame_ctrl_force_deep_recompute), then it means this is a
        # ctrl_force_deep_recompute too.
        key, f_args, f_kwargs, ctrl_args = self.parse_args(args, kwargs)

        ctrl = ctrl_args['ctrl']

        # pass the ctrl handle through to functions that declare a `ctrl`
        # parameter (e.g. train_mean_var_model below calls ctrl.checkpoint())
        params = self.f.func_code.co_varnames[:self.f.func_code.co_argcount]
        if 'ctrl' in params:
            f_kwargs['ctrl'] = ctrl

        if ctrl is None or ctrl_args['ctrl_ignore_cache']:
            return self.f(*f_args, **f_kwargs)
        if key:
            try:
                return self.get_cached_val(ctrl, key)
            except KeyError:
                pass
        f_rval = self.f(*f_args, **f_kwargs)
        if key:
            f_rval = self.cache_val(ctrl, key, f_rval)
        return f_rval

    def get_cached_val(self, ctrl, key):
        return ctrl.get(key)

    def cache_val(self, ctrl, key, val):
        ctrl.set(key, val)
        ctrl.r_lookup[id(val)] = key  # make ctrl.get_key(val) work for fresh values
        return val

class NumpyCacheCtrl(CtrlObjCacheWrapper):
    def get_cached_val(self, ctrl, key):
        filename = ctrl.get(key)['npy_filename']
        return numpy.load(filename)

    def cache_val(self, ctrl, key, val):
        try:
            filename = ctrl.get(key)['npy_filename']
        except KeyError:
            # the 'npy' suffix keeps numpy.save from appending '.npy' to a
            # name that get_cached_val would then fail to find
            handle, filename = ctrl.open_unique(suffix='npy')
            handle.close()
            ctrl.set(key, dict(npy_filename=filename))
        numpy.save(filename, val)
        ctrl.r_lookup[id(val)] = key  # make ctrl.get_key(val) work for fresh values
        return val

class PickleCacheCtrl(CtrlObjCacheWrapper):
    def __init__(self, protocol=0, **kwargs):
        self.protocol = protocol
        super(PickleCacheCtrl, self).__init__(**kwargs)

    def get_cached_val(self, ctrl, key):
        return cPickle.loads(ctrl.get(key)['cPickle_str'])

    def cache_val(self, ctrl, key, val):
        ctrl.set(key, dict(cPickle_str=cPickle.dumps(val, self.protocol)))
        ctrl.r_lookup[id(val)] = key  # make ctrl.get_key(val) work for fresh values
        return val

@NumpyCacheCtrl.decorate()
def get_raw_data(rows, cols, seed=67273):
    return numpy.random.RandomState(seed).randn(rows, cols)

@NumpyCacheCtrl.decorate()
def get_whitened_dataset(X, pca, max_components=5):
    return X[:,:max_components]

@PickleCacheCtrl.decorate(protocol=-1)
def get_pca(X, max_components=100):
    return dict(
            mean=0,
            eigvals=numpy.ones(X.shape[1]),
            eigvecs=numpy.identity(X.shape[1])
            )

@PickleCacheCtrl.decorate(protocol=-1)
def train_mean_var_model(data, ctrl):
    mean = numpy.zeros(data.shape[1])
    meansq = numpy.zeros(data.shape[1])
    for i in xrange(data.shape[0]):
        # running averages: on the first iteration alpha == 1.0, so the
        # initial zeros are completely overwritten
        alpha = 1.0 / (i + 1)
        mean = (1 - alpha) * mean + data[i] * alpha
        meansq = (1 - alpha) * meansq + (data[i] ** 2) * alpha
        ctrl.checkpoint()
    return (mean, meansq)

def test_run_experiment():

    # Could use db, or filesystem, or both, etc.
    # There would be generic ones, but the experimenter should be very aware
    # of what is being cached where, when, and how.  This is how results are
    # stored and retrieved, after all.  Cluster-friendly jobs should not use
    # local files directly, but should store cached computations and results
    # in such a database.
    # Different jobs should avoid using the same keys in the database,
    # because coordinating writes is difficult and conflicts will inevitably
    # arise.
    ctrl = memory_ctrl_obj(rootdir=tempfile.mkdtemp())  # NumpyCacheCtrl writes .npy files under rootdir

    # rows/cols values here are placeholders; the original sketch left them
    # unspecified
    raw_data = get_raw_data(rows=100, cols=50, ctrl=ctrl)
    raw_data_key = ctrl.get_key(raw_data)

    pca = get_pca(raw_data, max_components=30, ctrl=ctrl,
            ctrl_key_X=raw_data_key)  # raw_data is unhashable; use its db key
    whitened_data = get_whitened_dataset(raw_data, pca, ctrl=ctrl)

    mean_var = train_mean_var_model(
            data=whitened_data + 66,
            ctrl=ctrl,
            ctrl_key_data=whitened_data)  # tell the cache that this temporary is tied to whitened_data

    mean, meansq = mean_var

    #TODO: Test that the cache actually worked!!
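
# A minimal sketch of the missing cache check above, using only the
# in-memory db (the pickle cache needs no rootdir).  The key name 'X0' is
# hypothetical.
def test_cache_hit():
    ctrl = memory_ctrl_obj()
    X = numpy.ones((4, 3))
    pca1 = get_pca(X, ctrl=ctrl, ctrl_key_X='X0')  # miss: computes and stores
    n_docs = len(ctrl.db)
    pca2 = get_pca(X, ctrl=ctrl, ctrl_key_X='X0')  # hit: loaded from the db
    assert len(ctrl.db) == n_docs  # the second call added no new documents
    assert numpy.allclose(pca1['eigvals'], pca2['eigvals'])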