changeset 546:cb8eabe7d941

merged
author desjagui@atchoum.iro.umontreal.ca
date Tue, 18 Nov 2008 19:42:12 -0500
parents 24dfe316e79a (current diff) de6de7c2c54b (diff)
children d3791c59f36e 0ac4927e9d97
files
diffstat 6 files changed, 285 insertions(+), 118 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/algorithms/rbm.py	Tue Nov 18 19:31:56 2008 -0500
+++ b/pylearn/algorithms/rbm.py	Tue Nov 18 19:42:12 2008 -0500
@@ -11,6 +11,7 @@
 from ..datasets import make_dataset
 from .minimizer import make_minimizer
 from .stopper import make_stopper
+from ..dbdict.experiment import subdict, subdict_copy  # train_rbm calls subdict_copy
 
 class RBM(T.RModule):
 
@@ -75,24 +76,24 @@
 
 
 def train_rbm(state, channel=lambda *args, **kwargs:None):
-    dataset = make_dataset(**state.subdict(prefix='dataset_'))
+    dataset = make_dataset(**subdict_copy(state, prefix='dataset_'))
     train = dataset.train
 
     rbm_module = RBM(
             nvis=train.x.shape[1],
-            nhid=state.nhid)
+            nhid=state['nhid'])
     rbm = rbm_module.make()
 
-    batchsize = getattr(state,'batchsize', 1)
-    verbose = getattr(state,'verbose',1)
+    batchsize = state.get('batchsize', 1)
+    verbose = state.get('verbose', 1)
     iter = [0]
 
-    while iter[0] != state.max_iters:
+    while iter[0] != state['max_iters']:
         for j in xrange(0,len(train.x)-batchsize+1,batchsize):
             rbm.cd1(train.x[j:j+batchsize])
             if verbose > 1:
                 print 'estimated train cost...'
-            if iter[0] == state.max_iters:
+            if iter[0] == state['max_iters']:
                 break
             else:
                 iter[0] += 1
--- a/pylearn/dbdict/dbdict_run.py	Tue Nov 18 19:31:56 2008 -0500
+++ b/pylearn/dbdict/dbdict_run.py	Tue Nov 18 19:42:12 2008 -0500
@@ -1,4 +1,4 @@
-import sys
+import sys, signal
 from .tools import DictProxyState, load_state_fn, run_state
 
 # N.B.
@@ -42,7 +42,19 @@
             self.help()
             return
 
-        run_state(DictProxyState(dct))
+        channel_rval = [None]
+
+        def on_sigterm(signo, frame):
+            channel_rval[0] = 'stop'
+
+        #install SIGTERM and SIGINT handlers that ask the run_state function to return
+        signal.signal(signal.SIGTERM, on_sigterm)
+        signal.signal(signal.SIGINT, on_sigterm)
+
+        def channel(*args, **kwargs):
+            return channel_rval[0]
+
+        run_state(dct, channel)
         print dct
 
     def help(self):
--- a/pylearn/dbdict/experiment.py	Tue Nov 18 19:31:56 2008 -0500
+++ b/pylearn/dbdict/experiment.py	Tue Nov 18 19:42:12 2008 -0500
@@ -1,5 +1,7 @@
+"""Helper code for implementing dbdict-compatible jobs"""
+import inspect, sys, copy
 
-#These should be
+#State values should be instances of these types:
 INT = type(0)
 FLT = type(0.0)
 STR = type('')
@@ -7,96 +9,114 @@
 COMPLETE = None    #jobs can return this by returning nothing as well
 INCOMPLETE = True  #jobs can return this and be restarted
 
-
-class Experiment(object):
+def subdict(dct, prefix):
+    """Return the dictionary formed by keys in `dct` that start with the string `prefix`.
 
-    new_stdout = 'job_stdout'
-    new_stderr = 'job_stderr'
-
-    def remap_stdout(self):
-        """
-        Called before start and resume.
+    In the returned dictionary, the `prefix` is removed from the keynames.
+    Updates to the sub-dict are reflected in the original dictionary.
 
-        Default behaviour is to replace sys.stdout with open(self.new_stdout, 'w+').
-
-        """
-        if self.new_stdout:
-            sys.stdout = open(self.new_stdout, 'w+')
-
-    def remap_stderr(self):
-        """
-        Called before start and resume.
+    Example:
+        a = {'aa':0, 'ab':1, 'bb':2}
+        s = subdict(a, 'a') # returns dict-like object with keyvals {'a':0, 'b':1}
+        s['a'] = 5
+        s['c'] = 9
+        # a == {'aa':5, 'ab':1, 'ac':9, 'bb':2}
 
-        Default behaviour is to replace sys.stderr with open(self.new_stderr, 'w+').
-        
-        """
-        if self.new_stderr:
-            sys.stderr = open(self.new_stderr, 'w+')
-
-    def tempdir(self):
-        """
-        Return the recommended filesystem location for temporary files.
+    """
+    class SubDict(object):  # live view of the `dct` entries whose key starts with `prefix`
+        def __copy__(s):  # copying yields a plain dict, detached from `dct`
+            rval = {}
+            rval.update(s)
+            return rval
+        def __eq__(s, other):  # equal when sizes match and every key of `other` maps to the same value
+            if len(s) != len(other):
+                return False
+            for k in other:
+                if (prefix + k) not in dct or other[k] != s[k]:  # a missing key means unequal, not KeyError
+                    return False
+            return True
+        def __len__(s):
+            return len(s.items())
+        def __str__(s):  # rendered like the equivalent plain dict
+            d = {}
+            d.update(s)
+            return str(d)
+        def keys(s):  # key names are returned with `prefix` stripped
+            return [k[len(prefix):] for k in dct if k.startswith(prefix)]
+        def values(s):
+            return [dct[k] for k in dct if k.startswith(prefix)]
+        def items(s):
+            return [(k[len(prefix):],dct[k]) for k in dct if k.startswith(prefix)]
+        def update(s, other):  # writes every pair of `other` through to `dct`
+            for k,v in other.items():
+                s[k] = v  # goes through __setitem__, which prepends `prefix`
+        def __getitem__(s, a):
+            return dct[prefix+a]
+        def __setitem__(s, a, v):
+            dct[prefix+a] = v
 
-        The idea is that this will be a fast, local disk partition, suitable
-        for temporary storage.
-        
-        Files here will not generally be available at the time of resume().
+    return SubDict()
+
+def subdict_copy(dct, prefix):  # plain-dict snapshot of subdict(dct, prefix); writes to it do not reach `dct`
+    return copy.copy(subdict(dct, prefix))
 
-        The return value of this function may be controlled by one or more
-        environment variables.  
-        
-        Will return $DBDICT_EXPERIMENT_TEMPDIR if present.
-        Failing that, will return $TMP/username-dbdict/hash(self)
-        Failing that, will return /tmp/username-dbdict/hash(self)
+def call_with_kwargs_from_dict(fn, dct, logfile='stderr'):
+    """Call function `fn` with kwargs taken from dct.
 
-        .. note::
-            Maybe we should use Python stdlib's tempdir mechanism. 
-
-        """
+    When fn has a '**' parameter, this function is equivalent to fn(**dct).
 
-        print >> sys.stderr, "TODO: get tempdir correctly"
-        return '/tmp/dbdict-experiment'
-
+    When fn has no '**' parameter, this function removes keys from dct which are not parameter
+    names of `fn`.  The keys which are ignored in this way are logged to the `logfile`.  If
+    logfile is the string 'stdout' or 'stderr', then errors are logged to sys.stdout or
+    sys.stderr respectively.
 
-    def __init__(self, state):
-        """Called once per lifetime of the class instance.  Can be used to
-        create new jobs and save them to the database.   This function will not
-        be called when a Job is retrieved from the database.
-
-        Parent creates keys: dbdict_id, dbdict_module, dbdict_symbol, dbdict_status.
+    The reason this function exists is to make it easier to provide default arguments and to tolerate extra keys in `dct`.
 
-        """
-
-    def start(self):
-        """Called once per lifetime of the compute job.
+    :param fn: function to call
+    :param dct: dictionary from which to take arguments of fn
+    :param logfile: log ignored keys to this file
+    :type logfile: file-like object or string 'stdout' or string 'stderr'
 
-        This is a good place to initialize internal variables.
-
-        After this function returns, either stop() or run() will be called.
-        
-        dbdict_status -> RUNNING
+    :returns: fn(**<something>)
 
-        """
-
-    def resume(self):
-        """Called to resume computations on a previously stop()'ed job.  The
-        os.getcwd() is just like it was after some previous stop() command.
-
-        This is a good place to load internal variables from os.getcwd().
+    """
+    argspec = inspect.getargspec(fn)
+    argnames = argspec[0]
+    if argspec[2] == None: #if there is no room for a **args type-thing in fn...
+        kwargs = {}
+        for k,v in dct.items():
+            if k in argnames:
+                kwargs[k] = v
+            else:
+                if not logfile:
+                    pass
+                elif logfile == 'stderr':
+                    print >> sys.stderr, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v
+                elif logfile == 'stdout':
+                    print >> sys.stdout, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v
+                else:
+                    print >> logfile, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v
+        return fn(**kwargs)
+    else:
+        #there is a **args type thing in fn. Here we pass everything.
+        return fn(**dct)
 
-        dbdict_status -> RUNNING
-        
-        """
-        return self.start()
+#MAKE YOUR OWN DBDICT-COMPATIBLE EXPERIMENTS IN THIS MODEL
+def sample_experiment(state, channel):
+
+    #read from the state to obtain parameters, configuration, etc.
+    print >> sys.stdout, state.items()
 
-    def run(self, channel):
-        """Called after start() or resume().
-        
-        channel() may return different things at different times.  
-            None   - run should continue.
-            'stop' - the job should save state as soon as possible because
-                     the process may soon be terminated
+    import time
+    for i in xrange(100):
+        time.sleep(1)
+        # use the channel to know if the job should stop ASAP
+        if channel() == 'stop':
+            break
 
-        When this function returns, dbdict_status -> DONE.
-        """
+    # modify state to record results
+    state['answer'] = 42
 
+    #return either INCOMPLETE or COMPLETE to indicate that the job should be re-run or not.
+    return COMPLETE
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/dbdict/tests/test_experiment.py	Tue Nov 18 19:42:12 2008 -0500
@@ -0,0 +1,66 @@
+from pylearn.dbdict.experiment import *
+from unittest import TestCase
+
+import StringIO
+
+
+class T_subdict(TestCase):
+    # Checks the prefix view returned by subdict().
+    def test0(self):  # writes propagate between the view and the parent dict
+        a = {'aa':0, 'ab':1, 'bb':2}
+        s = subdict(a, 'a') # returns dict-like object with keyvals {'a':0, 'b':1}
+        s['a'] = 5
+        s['c'] = 9
+        #a key written through the view is readable from the view and, prefixed, from the parent
+        self.failUnless(s['c'] == 9)
+        self.failUnless(a['ac'] == 9)
+
+        #check that the subview has the right stuff
+        sitems = s.items()
+        sitems.sort()
+        self.failUnless(sitems == [('a', 5), ('b', 1), ('c', 9)], str(sitems))
+        self.failUnless(a['bb'] == 2)
+
+        #add to the subview via the parent
+        a['az'] = -1
+        self.failUnless(s['z'] == -1)
+
+    def test1(self):  # a plain dict filled from the view compares equal to the view
+        a = {'aa':0, 'ab':1, 'bb':2}
+
+        s = subdict(a, 'a')
+        #copy the view into a plain dict
+        r = {}
+        r.update(s)
+        #the copy must match the view in size and content
+        self.failUnless(len(r) == len(s))
+        self.failUnless(r == s, (str(r), str(s)))
+
+
+class T_call_with_kwargs_from_dict(TestCase):
+
+    def test0(self):
+        #f has no ** parameter: keys that are not argument names are dropped and logged
+        def f(a, c=5):
+            return a+c
+        #g has a ** parameter: every key is passed through and nothing is logged
+        def g(a, **dct):
+            return a + dct['c']
+
+
+        kwargs = dict(a=1, b=2, c=3)
+        #ignored keys are reported to this in-memory log
+        io = StringIO.StringIO()
+
+        self.failUnless(call_with_kwargs_from_dict(f, kwargs, logfile=io) == 4)  # a + c; b is ignored
+        self.failUnless(io.getvalue() == \
+                "WARNING: DictProxyState.call_substate ignoring key-value pair: b 2\n")
+        self.failUnless(call_with_kwargs_from_dict(g, kwargs, logfile=io) == 4)  # the log must stay unchanged
+        self.failUnless(io.getvalue() == \
+                "WARNING: DictProxyState.call_substate ignoring key-value pair: b 2\n")
+        del kwargs['c']  # f now falls back to its default c=5
+        self.failUnless(call_with_kwargs_from_dict(f, kwargs, logfile=io) == 6)
+        self.failUnless(io.getvalue() ==  \
+                ("WARNING: DictProxyState.call_substate ignoring key-value pair: b 2\n"
+                "WARNING: DictProxyState.call_substate ignoring key-value pair: b 2\n"))
+
--- a/pylearn/dbdict/tools.py	Tue Nov 18 19:31:56 2008 -0500
+++ b/pylearn/dbdict/tools.py	Tue Nov 18 19:42:12 2008 -0500
@@ -1,6 +1,8 @@
-import sys
+"""Helper code for dbdict job drivers."""
 
-from .experiment  import COMPLETE, INCOMPLETE
+import sys, inspect
+
+from .experiment  import COMPLETE, INCOMPLETE, subdict
 
 MODULE = 'dbdict_module'
 SYMBOL = 'dbdict_symbol'
@@ -13,18 +15,38 @@
 #this proxy object lets experiments use a dict like a state object
 #
 def DictProxyState(dct):
+    """Convenient dict -> object interface for the state parameters of dbdict jobs.
+
+    In the dbdict job running protocol, the user provides a job as a function with two
+    arguments:
+        
+        def myjob(state, channel):
+            a = getattr(state, 'a', blah)
+            b = state.blah
+            ...
+
+    In the case that the caller of myjob has the attributes of this `state` in a dictionary,
+    then this `DictProxyState` function returns an appropriate object, whose attributes are
+    backed by this dictionary.
+
+    """
     defaults_obj = [None]
     class Proxy(object):
-        def subdict(s, prefix=''):
-            rval = {}
-            for k,v in dct.items():
-                if k.startswith(prefix):
-                    rval[k[len(prefix):]] = v
-            return rval
+        def substate(s, prefix=''):
+            return DictProxyState(subdict(dct, prefix))
+
         def use_defaults(s, obj):
+            """Use `obj` to retrieve values when they are not in the `dict`.
+
+            :param obj: a dictionary of default values.
+            """
             defaults_obj[0] = obj
 
         def __getitem__(s,a):
+            """Returns key `a` from the underlying dict, or from the defaults.
+            
+            Raises `KeyError` on failure.
+            """
             try:
                 return dct[a]
             except Exception, e:
@@ -34,9 +56,14 @@
                     raise e
 
         def __setitem__(s,a,v):
+            """Sets key `a` equal to `v` in the underlying dict.  """
             dct[a] = v
 
         def __getattr__(s,a):
+            """Returns value of key `a` from the underlying dict first, then from the defaults.
+
+            Raises `AttributeError` on failure.
+            """
             try:
                 return dct[a]
             except KeyError:
@@ -50,10 +77,10 @@
     #
     # load the experiment class 
     #
-    dbdict_module_name = getattr(state,MODULE)
-    dbdict_symbol = getattr(state, SYMBOL)
+    dbdict_module_name = state[MODULE]
+    dbdict_symbol = state[SYMBOL]
 
-    preimport_list = getattr(state, PREIMPORT, "").split()
+    preimport_list = state.get(PREIMPORT, "").split()
     for preimport in preimport_list:
         __import__(preimport, fromlist=[None], level=0)
 
@@ -75,3 +102,42 @@
         print >> sys.stderr, "WARNING: INVALID job function return value"
     return rval
 
+def call_with_kwargs_from_dict(fn, dct, logfile='stderr'):
+    """Call function `fn` with kwargs taken from dct.
+
+    When fn has a '**' parameter, this function is equivalent to fn(**dct).
+
+    When fn has no '**' parameter, this function removes keys from dct which are not parameter
+    names of `fn`.  The keys which are ignored in this way are logged to the `logfile`.  If
+    logfile is the string 'stdout' or 'stderr', then errors are logged to sys.stdout or
+    sys.stderr respectively.
+
+    :param fn: function to call
+    :param dct: dictionary from which to take arguments of fn
+    :param logfile: log ignored keys to this file
+    :type logfile: file-like object or string 'stdout' or string 'stderr'
+
+    :returns: fn(**<something>)
+
+    """
+    argspec = inspect.getargspec(fn)
+    argnames = argspec[0]
+    if argspec[2] is None: #if there is no room for a **args type-thing in fn...
+        kwargs = {}
+        for k,v in dct.items():
+            if k in argnames:
+                kwargs[k] = v
+            else:
+                if not logfile:
+                    pass
+                elif logfile == 'stderr':
+                    print >> sys.stderr, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v
+                elif logfile == 'stdout':
+                    print >> sys.stdout, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v
+                else:
+                    print >> logfile, "WARNING: DictProxyState.call_substate ignoring key-value pair:", k, v
+        return fn(**kwargs)
+    else:
+        #there is a **args type thing in fn. Here we pass everything, as the docstring promises.
+        return fn(**dct)
+
--- a/pylearn/external/wrap_libsvm.py	Tue Nov 18 19:31:56 2008 -0500
+++ b/pylearn/external/wrap_libsvm.py	Tue Nov 18 19:42:12 2008 -0500
@@ -1,7 +1,8 @@
 """Run an experiment using libsvm.
 """
 import numpy
-from ..datasets import dataset_from_descr
+from ..datasets import make_dataset
+from ..dbdict.experiment import subdict_copy
 
 # libsvm currently has no python installation instructions/convention.
 #
@@ -54,30 +55,32 @@
     This is the kind of function that dbdict-run can use.
 
     """
-    ((train_x, train_y), (valid_x, valid_y), (test_x, test_y)) = dataset_from_descr(state.dataset)
+    dataset = make_dataset(**subdict_copy(state, 'dataset_'))
+            
+
 
     #libsvm needs stuff in int32 on a 32bit machine
     #TODO: test this on a 64bit machine
-    train_y = numpy.asarray(train_y, dtype='int32')
-    valid_y = numpy.asarray(valid_y, dtype='int32')
-    test_y = numpy.asarray(test_y, dtype='int32')
-    problem = svm.svm_problem(train_y, train_x);
+    train_y = numpy.asarray(dataset.train.y, dtype='int32')
+    valid_y = numpy.asarray(dataset.valid.y, dtype='int32')
+    test_y = numpy.asarray(dataset.test.y, dtype='int32')
+    problem = libsvm.svm_problem(train_y, dataset.train.x);
 
-    gamma0 = 0.5 / numpy.sum(numpy.var(train_x, axis=0))
+    gamma0 = 0.5 / numpy.sum(numpy.var(dataset.train.x, axis=0))
 
-    param = svm.svm_parameter(C=state.C,
-            kernel_type=getattr(svm, state.kernel),
-            gamma=state.rel_gamma * gamma0)
+    param = libsvm.svm_parameter(C=state['C'],
+            kernel_type=getattr(libsvm, state['kernel']),
+            gamma=state['rel_gamma'] * gamma0)
 
-    model = svm.svm_model(problem, param) #this is the expensive part
+    model = libsvm.svm_model(problem, param) #this is the expensive part
 
-    state.train_01 = score_01(train_x, train_y, model)
-    state.valid_01 = score_01(valid_x, valid_y, model)
-    state.test_01 = score_01(test_x, test_y, model)
+    state['train_01'] = score_01(dataset.train.x, train_y, model)
+    state['valid_01'] = score_01(dataset.valid.x, valid_y, model)
+    state['test_01'] = score_01(dataset.test.x, test_y, model)
 
-    state.n_train = len(train_y)
-    state.n_valid = len(valid_y)
-    state.n_test = len(test_y)
+    state['n_train'] = len(train_y)
+    state['n_valid'] = len(valid_y)
+    state['n_test'] = len(test_y)
 
 def run_svm_experiment(**kwargs):
     """Python-friendly interface to dbdict_run_svm_experiment
@@ -93,7 +96,6 @@
         # 0.14, 0.10  #.. or something...
 
     """
-    state = State(**kwargs)
-    state_run_svm_experiment(state)
-    return state
+    state_run_svm_experiment(state=kwargs)
+    return kwargs