Mercurial > pylearn
changeset 542:ee5324c21e60
changes to dbdict to use dict-like instead of object-like state
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 17 Nov 2008 18:38:32 -0500 |
parents | 85d3300c9a9c |
children | 34aba0efa3e9 |
files | pylearn/algorithms/rbm.py pylearn/dbdict/dbdict_run.py pylearn/dbdict/experiment.py pylearn/dbdict/tests/test_experiment.py pylearn/dbdict/tools.py pylearn/external/wrap_libsvm.py |
diffstat | 6 files changed, 285 insertions(+), 118 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/algorithms/rbm.py Thu Nov 13 17:54:56 2008 -0500 +++ b/pylearn/algorithms/rbm.py Mon Nov 17 18:38:32 2008 -0500 @@ -12,6 +12,8 @@ from .minimizer import make_minimizer from .stopper import make_stopper +from ..dbdict.experiment import subdict + class RBM(module.FancyModule): # is it really necessary to pass ALL of these ? - GD @@ -51,23 +53,23 @@ pass; def train_rbm(state, channel=lambda *args, **kwargs:None): - dataset = make_dataset(**state.subdict(prefix='dataset_')) + dataset = make_dataset(**subdict_copy(state, prefix='dataset_')) train = dataset.train rbm_module = RBM( nvis=train.x.shape[1], - nhid=state.size_hidden) + nhid=state['size_hidden']) - batchsize = state.batchsize - verbose = state.verbose + batchsize = state['batchsize'] + verbose = state['verbose'] iter = [0] - while iter[0] != state.max_iters: + while iter[0] != state['max_iters']: for j in xrange(0,len(train.x)-batchsize+1,batchsize): rbm.cd1(train.x[j:j+batchsize]) if verbose > 1: print 'estimated train cost...' - if iter[0] == state.max_iters: + if iter[0] == state['max_iters']: break else: iter[0] += 1
--- a/pylearn/dbdict/dbdict_run.py Thu Nov 13 17:54:56 2008 -0500 +++ b/pylearn/dbdict/dbdict_run.py Mon Nov 17 18:38:32 2008 -0500 @@ -1,4 +1,4 @@ -import sys +import sys, signal from .tools import DictProxyState, load_state_fn, run_state # N.B. @@ -42,7 +42,18 @@ self.help() return - run_state(DictProxyState(dct)) + channel_rval = [None] + + def on_sigterm(signo, frame): + channel_rval[0] = 'stop' + + #install a SIGTERM handler that asks the run_state function to return + signal.signal(signal.SIGTERM, on_sigterm) + + def channel(*args, **kwargs): + return channel_rval[0] + + run_state(dct, channel) print dct def help(self):
--- a/pylearn/dbdict/experiment.py Thu Nov 13 17:54:56 2008 -0500 +++ b/pylearn/dbdict/experiment.py Mon Nov 17 18:38:32 2008 -0500 @@ -1,5 +1,7 @@ +"""Helper code for implementing dbdict-compatible jobs""" +import inspect, sys, copy -#These should be +#State values should be instances of these types: INT = type(0) FLT = type(0.0) STR = type('') @@ -7,96 +9,114 @@ COMPLETE = None #jobs can return this by returning nothing as well INCOMPLETE = True #jobs can return this and be restarted - -class Experiment(object): +def subdict(dct, prefix): + """Return the dictionary formed by keys in `dct` that start with the string `prefix`. - new_stdout = 'job_stdout' - new_stderr = 'job_stderr' - - def remap_stdout(self): - """ - Called before start and resume. + In the returned dictionary, the `prefix` is removed from the keynames. + Updates to the sub-dict are reflected in the original dictionary. - Default behaviour is to replace sys.stdout with open(self.new_stdout, 'w+'). - - """ - if self.new_stdout: - sys.stdout = open(self.new_stdout, 'w+') - - def remap_stderr(self): - """ - Called before start and resume. + Example: + a = {'aa':0, 'ab':1, 'bb':2} + s = subdict(a, 'a') # returns dict-like object with keyvals {'a':0, 'b':1} + s['a'] = 5 + s['c'] = 9 + # a == {'aa':5, 'ab':1, 'ac':9, 'bb':2} - Default behaviour is to replace sys.stderr with open(self.new_stderr, 'w+'). - - """ - if self.new_stderr: - sys.stderr = open(self.new_stderr, 'w+') - - def tempdir(self): - """ - Return the recommended filesystem location for temporary files. + """ + class SubDict(object): + def __copy__(s): + rval = {} + rval.update(s) + return rval + def __eq__(s, other): + if len(s) != len(other): + return False + for k in other: + if other[k] != s[k]: + return False + return True + def __len__(s): + return len(s.items()) + def __str__(s): + d = {} + d.update(s) + return str(d) + def keys(s): + return [k[len(prefix):] for k in dct if k.startswith(prefix)] + def values(s): + return [dct[k] for k in dct if k.startswith(prefix)] + def items(s): + return [(k[len(prefix):],dct[k]) for k in dct if k.startswith(prefix)] + def update(s, other): + for k,v in other.items(): + self[k] = v + def __getitem__(s, a): + return dct[prefix+a] + def __setitem__(s, a, v): + dct[prefix+a] = v - The idea is that this will be a fast, local disk partition, suitable - for temporary storage. - - Files here will not generally be available at the time of resume(). + return SubDict() + +def subdict_copy(dct, prefix): + return copy.copy(subdict(dct, prefix)) - The return value of this function may be controlled by one or more - environment variables. - - Will return $DBDICT_EXPERIMENT_TEMPDIR if present. - Failing that, will return $TMP/username-dbdict/hash(self) - Failing that, will return /tmp/username-dbdict/hash(self) +def call_with_kwargs_from_dict(fn, dct, logfile='stderr'): + """Call function `fn` with kwargs taken from dct. - .. note:: - Maybe we should use Python stdlib's tempdir mechanism. - - """ + When fn has a '**' parameter, this function is equivalent to fn(**dct). - print >> sys.stderr, "TODO: get tempdir correctly" - return '/tmp/dbdict-experiment' - + When fn has no '**' parameter, this function removes keys from dct which are not parameter + names of `fn`. The keys which are ignored in this way are logged to the `logfile`. If + logfile is the string 'stdout' or 'stderr', then errors are logged to sys.stdout or + sys.stderr respectively. - def __init__(self, state): - """Called once per lifetime of the class instance. Can be used to - create new jobs and save them to the database. This function will not - be called when a Job is retrieved from the database. - - Parent creates keys: dbdict_id, dbdict_module, dbdict_symbol, dbdict_status. + The reason this function exists is to make it easier to provide default arguments and - """ - - def start(self): - """Called once per lifetime of the compute job. + :param fn: function to call + :param dct: dictionary from which to take arguments of fn + :param logfile: log ignored keys to this file + :type logfile: file-like object or string 'stdout' or string 'stderr' - This is a good place to initialize internal variables. - - After this function returns, either stop() or run() will be called. - - dbdict_status -> RUNNING + :returns: fn(**<something>) - """ - - def resume(self): - """Called to resume computations on a previously stop()'ed job. The - os.getcwd() is just like it was after some previous stop() command. - - This is a good place to load internal variables from os.getcwd(). + """ + argspec = inspect.getargspec(fn) + argnames = argspec[0] + if argspec[2] == None: #if there is no room for a **args type-thing in fn... + kwargs = {} + for k,v in dct.items(): + if k in argnames: + kwargs[k] = v + else: + if not logfile: + pass + elif logfile == 'stderr': + print >> sys.stderr, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v + elif logfile == 'stdout': + print >> sys.stdout, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v + else: + print >> logfile, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v + return fn(**kwargs) + else: + #there is a **args type thing in fn. Here we pass everything. + return fn(**dct) - dbdict_status -> RUNNING - - """ - return self.start() +#MAKE YOUR OWN DBDICT-COMPATIBLE EXPERIMENTS IN THIS MODEL +def sample_experiment(state, channel): + + #read from the state to obtain parameters, configuration, etc. + print >> sys.stdout, state.items() - def run(self, channel): - """Called after start() or resume(). - - channel() may return different things at different times. - None - run should continue. - 'stop' - the job should save state as soon as possible because - the process may soon be terminated + import time + for i in xrange(100): + time.sleep(1) + # use the channel to know if the job should stop ASAP + if channel() == 'stop': + break - When this function returns, dbdict_status -> DONE. - """ + # modify state to record results + state['answer'] = 42 + #return either INCOMPLETE or COMPLETE to indicate that the job should be re-run or not. + return COMPLETE +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dbdict/tests/test_experiment.py Mon Nov 17 18:38:32 2008 -0500 @@ -0,0 +1,66 @@ +from pylearn.dbdict.experiment import * +from unittest import TestCase + +import StringIO + + +class T_subdict(TestCase): + + def test0(self): + a = {'aa':0, 'ab':1, 'bb':2} + s = subdict(a, 'a') # returns dict-like object with keyvals {'a':0, 'b':1} + s['a'] = 5 + s['c'] = 9 + + self.failUnless(s['c'] == 9) + self.failUnless(a['ac'] == 9) + + #check that the subview has the right stuff + sitems = s.items() + sitems.sort() + self.failUnless(sitems == [('a', 5), ('b', 1), ('c', 9)], str(sitems)) + self.failUnless(a['bb'] == 2) + + #add to the subview via the parent + a['az'] = -1 + self.failUnless(s['z'] == -1) + + def test1(self): + a = {'aa':0, 'ab':1, 'bb':2} + + s = subdict(a, 'a') + + r = {} + r.update(s) + + self.failUnless(len(r) == len(s)) + self.failUnless(r == s, (str(r), str(s))) + + +class T_call_with_kwargs_from_dict(TestCase): + + def test0(self): + + def f(a, c=5): + return a+c + + def g(a, **dct): + return a + dct['c'] + + + kwargs = dict(a=1, b=2, c=3) + + io = StringIO.StringIO() + + self.failUnless(call_with_kwargs_from_dict(f, kwargs, logfile=io) == 4) + self.failUnless(io.getvalue() == \ + "WARNING: DictProxyState.call_substate ignoring key-value pair: b 2\n") + self.failUnless(call_with_kwargs_from_dict(g, kwargs, logfile=io) == 4) + self.failUnless(io.getvalue() == \ + "WARNING: DictProxyState.call_substate ignoring key-value pair: b 2\n") + del kwargs['c'] + self.failUnless(call_with_kwargs_from_dict(f, kwargs, logfile=io) == 6) + self.failUnless(io.getvalue() == \ + ("WARNING: DictProxyState.call_substate ignoring key-value pair: b 2\n" + "WARNING: DictProxyState.call_substate ignoring key-value pair: b 2\n")) +
--- a/pylearn/dbdict/tools.py Thu Nov 13 17:54:56 2008 -0500 +++ b/pylearn/dbdict/tools.py Mon Nov 17 18:38:32 2008 -0500 @@ -1,6 +1,8 @@ -import sys +"""Helper code for dbdict job drivers.""" -from .experiment import COMPLETE, INCOMPLETE +import sys, inspect + +from .experiment import COMPLETE, INCOMPLETE, subdict MODULE = 'dbdict_module' SYMBOL = 'dbdict_symbol' @@ -13,18 +15,38 @@ #this proxy object lets experiments use a dict like a state object # def DictProxyState(dct): + """Convenient dict -> object interface for the state parameters of dbdict jobs. + + In the dbdict job running protocol, the user provides a job as a function with two + arguments: + + def myjob(state, channel): + a = getattr(state, 'a', blah) + b = state.blah + ... + + In the case that the caller of myjob has the attributes of this `state` in a dictionary, + then this `DictProxyState` function returns an appropriate object, whose attributes are + backed by this dictionary. + + """ defaults_obj = [None] class Proxy(object): - def subdict(s, prefix=''): - rval = {} - for k,v in dct.items(): - if k.startswith(prefix): - rval[k[len(prefix):]] = v - return rval + def substate(s, prefix=''): + return DictProxyState(subdict(dct, prefix)) + def use_defaults(s, obj): + """Use `obj` to retrieve values when they are not in the `dict`. + + :param obj: a dictionary of default values. + """ defaults_obj[0] = obj def __getitem__(s,a): + """Returns key `a` from the underlying dict, or from the defaults. + + Raises `KeyError` on failure. + """ try: return dct[a] except Exception, e: @@ -34,9 +56,14 @@ raise e def __setitem__(s,a,v): + """Sets key `a` equal to `v` in the underlying dict. """ dct[a] = v def __getattr__(s,a): + """Returns value of key `a` from the underlying dict first, then from the defaults. + + Raises `AttributeError` on failure. + """ try: return dct[a] except KeyError: @@ -50,10 +77,10 @@ # # load the experiment class # - dbdict_module_name = getattr(state,MODULE) - dbdict_symbol = getattr(state, SYMBOL) + dbdict_module_name = state[MODULE] + dbdict_symbol = state[SYMBOL] - preimport_list = getattr(state, PREIMPORT, "").split() + preimport_list = state.get(PREIMPORT, "").split() for preimport in preimport_list: __import__(preimport, fromlist=[None], level=0) @@ -75,3 +102,42 @@ print >> sys.stderr, "WARNING: INVALID job function return value" return rval +def call_with_kwargs_from_dict(fn, dct, logfile='stderr'): + """Call function `fn` with kwargs taken from dct. + + When fn has a '**' parameter, this function is equivalent to fn(**dct). + + When fn has no '**' parameter, this function removes keys from dct which are not parameter + names of `fn`. The keys which are ignored in this way are logged to the `logfile`. If + logfile is the string 'stdout' or 'stderr', then errors are logged to sys.stdout or + sys.stderr respectively. + + :param fn: function to call + :param dct: dictionary from which to take arguments of fn + :param logfile: log ignored keys to this file + :type logfile: file-like object or string 'stdout' or string 'stderr' + + :returns: fn(**<something>) + + """ + argspec = inspect.getargspec(fn) + argnames = argspec[0] + if argspec[2] == None: #if there is no room for a **args type-thing in fn... + kwargs = {} + for k,v in dct.items(): + if k in argnames: + kwargs[k] = v + else: + if not logfile: + pass + elif logfile == 'stderr': + print >> sys.stderr, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v + elif logfile == 'stdout': + print >> sys.stdout, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v + else: + print >> logfile, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v + return fn(**kwargs) + else: + #there is a **args type thing in fn. Here we pass everything. + return fn(**s.subdict(prefix)) +
--- a/pylearn/external/wrap_libsvm.py Thu Nov 13 17:54:56 2008 -0500 +++ b/pylearn/external/wrap_libsvm.py Mon Nov 17 18:38:32 2008 -0500 @@ -1,7 +1,8 @@ """Run an experiment using libsvm. """ import numpy -from ..datasets import dataset_from_descr +from ..datasets import make_dataset +from ..dbdict.experiment import subdict_copy # libsvm currently has no python installation instructions/convention. # @@ -54,30 +55,32 @@ This is the kind of function that dbdict-run can use. """ - ((train_x, train_y), (valid_x, valid_y), (test_x, test_y)) = dataset_from_descr(state.dataset) + dataset = make_dataset(**subdict_copy(state, 'dataset_')) + + #libsvm needs stuff in int32 on a 32bit machine #TODO: test this on a 64bit machine - train_y = numpy.asarray(train_y, dtype='int32') - valid_y = numpy.asarray(valid_y, dtype='int32') - test_y = numpy.asarray(test_y, dtype='int32') - problem = svm.svm_problem(train_y, train_x); + train_y = numpy.asarray(dataset.train.y, dtype='int32') + valid_y = numpy.asarray(dataset.valid.y, dtype='int32') + test_y = numpy.asarray(dataset.test.y, dtype='int32') + problem = libsvm.svm_problem(train_y, dataset.train.x); - gamma0 = 0.5 / numpy.sum(numpy.var(train_x, axis=0)) + gamma0 = 0.5 / numpy.sum(numpy.var(dataset.train.x, axis=0)) - param = svm.svm_parameter(C=state.C, - kernel_type=getattr(svm, state.kernel), - gamma=state.rel_gamma * gamma0) + param = libsvm.svm_parameter(C=state['C'], + kernel_type=getattr(libsvm, state['kernel']), + gamma=state['rel_gamma'] * gamma0) - model = svm.svm_model(problem, param) #this is the expensive part + model = libsvm.svm_model(problem, param) #this is the expensive part - state.train_01 = score_01(train_x, train_y, model) - state.valid_01 = score_01(valid_x, valid_y, model) - state.test_01 = score_01(test_x, test_y, model) + state['train_01'] = score_01(dataset.train.x, train_y, model) + state['valid_01'] = score_01(dataset.valid.x, valid_y, model) + state['test_01'] = score_01(dataset.test.x, test_y, model) - state.n_train = len(train_y) - state.n_valid = len(valid_y) - state.n_test = len(test_y) + state['n_train'] = len(train_y) + state['n_valid'] = len(valid_y) + state['n_test'] = len(test_y) def run_svm_experiment(**kwargs): """Python-friendly interface to dbdict_run_svm_experiment @@ -93,7 +96,6 @@ # 0.14, 0.10 #.. or something... """ - state = State(**kwargs) - state_run_svm_experiment(state) - return state + state_run_svm_experiment(state=kwargs) + return kwargs