changeset 566:a71971ccc1e4

checkpoint
author Olivier Breuleux <breuleuo@iro.umontreal.ca>
date Wed, 03 Dec 2008 16:07:04 -0500
parents 0ac4927e9d97
children d88c35e8f83a
files pylearn/dbdict/newstuff.py
diffstat 1 files changed, 244 insertions(+), 67 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dbdict/newstuff.py	Sat Nov 29 16:48:49 2008 -0500
+++ b/pylearn/dbdict/newstuff.py	Wed Dec 03 16:07:04 2008 -0500
@@ -1,6 +1,6 @@
 
 from collections import defaultdict
-import re
+import re, sys, inspect, os
 
 ################################################################################
 ### resolve
@@ -18,12 +18,14 @@
 ################################################################################
 
 def convert(obj):
+    if not isinstance(obj, str):
+        return obj
     def kw(x):
         x = x.lower()
         return dict(true = True,
                     false = False,
                     none = None)[x]
-    for f in (int, float, kw):
+    for f in (kw, int, float):
         try:
             return f(obj)
         except:
@@ -97,9 +99,14 @@
         elif len(s2) == 2:
             k, v = s2
             k += '.__builder__'
-        d[k] = v
+        d[k] = convert(v)
     return d
 
+def format_d(d, sep = '\n', space = True):
+    """Render a dictionary as a sequence of ``key = value`` entries.
+
+    :param d: dictionary to render; it is passed through ``flatten`` first.
+    :param sep: string placed between consecutive entries.
+    :param space: if true use ``"key = value"``, otherwise ``"key=value"``.
+    :returns: the entries joined by ``sep``, ordered by key.
+    """
+    pattern = "%s = %s" if space else "%s=%s"
+    # Sort by key so the output does not depend on dictionary iteration
+    # order: StandardChannel uses this string as a directory name, so the
+    # same contents must always produce the same text.
+    entries = sorted(flatten(d).items())
+    return sep.join(pattern % (k, v) for k, v in entries)
+
 def format_help(topic):
     if topic is None:
         return 'No help.'
@@ -115,97 +122,267 @@
     return s
 
 ################################################################################
-### running
+### single channels
 ################################################################################
 
+# greenlet is optional.  Try the standalone module first, then the one
+# exposed by the py package; if neither can be imported, warn on stderr and
+# bind the name to None so that callers can simply test ``if greenlet:``.
+try:
+    import greenlet
+except ImportError:
+    try:
+        from py import greenlet
+    except ImportError:
+        print >>sys.stderr, 'the greenlet module is unavailable'
+        greenlet = None
+
+class Complete(Exception):
+    """Raised by SingleChannel.run when the job has already completed."""
+class Incomplete(Exception):
+    """Counterpart of Complete; not raised by the code visible in this file section."""
+
 class Channel(object):
-    def abort(self):
-        raise StopIteration()
-    def checkpoint(self):
-        raise NotImplementedError()
-    def switch(self, job, *message):
-        raise NotImplementedError()
-    def broadcast(self, *message):
+    COMPLETE = None
+    INCOMPLETE = True
+    def save(self):
         raise NotImplementedError()
-    def run(self):
+    def switch(self, *message):
+        pass
+    def process_message(self, message):
         pass
-
-class StandardChannel(Channel):
-    def __init__(self, runner):
-        self.runner = runner
-    def checkpoint(self):
+    def setup(self):
+        pass
+    def run(self):
         pass
 
 
 
+        def on_sigterm(signo, frame):
+            channel_rval[0] = 'stop'
+
+        #install a SIGTERM handler that asks the run_state function to return
+        signal.signal(signal.SIGTERM, on_sigterm)
+        signal.signal(signal.SIGINT, on_sigterm)
 
 
+class SingleChannel(Channel):
+    """Channel that runs one experiment on one state dictionary.
+
+    ``experiment`` is called as ``experiment(channel, state)``.  A return
+    value that is ``Channel.COMPLETE`` marks the job as finished in
+    ``state['job']['complete']``; any other value leaves it unfinished.
+    """
+
+    def __init__(self, experiment, state):
+        self.experiment = experiment
+        self.state = state
+
+    # TODO: greenlet-based switching between the manager and the experiment
+    # (a ``switch`` override and an interactive ``run``) is not implemented.
+
+    def run(self, interactive = False):
+        """Set up the channel, run the experiment once and save the state.
+
+        :param interactive: request interactive mode.  It is only accepted
+            when the greenlet module could be imported; the flag is stored
+            on the instance but not otherwise used here.
+        :returns: whatever the experiment returned.
+        :raises Complete: if the state says the job has already completed.
+        :raises Exception: if ``interactive`` is set and greenlet is missing.
+        """
+        if interactive and not greenlet:
+            raise Exception('interactive mode requires the greenlet package to be installed (try easy_install greenlet or easy_install py)')
+        self.interactive = interactive
+        self.setup()
+        # NOTE(review): assumes state['job'] exists (or is created on
+        # access) once setup() has run -- confirm against expand().
+        self.state['job'].setdefault('complete', False)
+        if self.state['job']['complete']:
+            raise Complete('The job has already completed.')
+        v = self.experiment(self, self.state)
+        # COMPLETE is a class attribute of Channel, not a module-level name,
+        # so it has to be reached through the instance.
+        self.state['job']['complete'] = v is self.COMPLETE
+        self.save()
+        return v
+
+class StandardChannel(SingleChannel):
+    """SingleChannel that keeps its state in a directory on disk.
+
+    The directory is ``<root>/<dirname>``, where ``dirname`` is the state
+    formatted as comma-separated ``key=value`` entries.  ``setup`` writes
+    ``orig.conf`` the first time the directory is used and ``save`` writes
+    ``current.conf``; a later ``setup`` reloads the state from
+    ``current.conf`` when that file exists.
+    """
+
+    def __init__(self, root, experiment, state):
+        super(StandardChannel, self).__init__(experiment, state)
+        self.root = root
+        self.dirname = format_d(self.state, sep=',', space=False)
+        self.path = os.path.join(self.root, self.dirname)
+
+    def _write_conf(self, filename):
+        # Write the formatted state to ``filename`` (relative to the current
+        # directory).  The file is closed explicitly so its contents are
+        # flushed before anything else (e.g. an rsync push) reads it.
+        conf = open(filename, 'w')
+        try:
+            conf.write(format_d(self.state))
+            conf.write('\n')
+        finally:
+            conf.close()
+
+    def save(self):
+        """Write the current state to ``current.conf`` in the job directory.
+
+        Changes the working directory of the process to ``self.path``.
+        """
+        os.chdir(self.path)
+        self._write_conf('current.conf')
+
+    def setup(self):
+        """Create the job directory if needed and reload any saved state.
+
+        Changes the working directory of the process to ``self.path``.
+        ``orig.conf`` is only written when it does not exist yet, so it
+        keeps the state the directory was first set up with.
+        """
+        if not os.path.isdir(self.path):
+            os.mkdir(self.path)
+        os.chdir(self.path)
+        if not os.path.isfile('orig.conf'):
+            self._write_conf('orig.conf')
+        if os.path.isfile('current.conf'):
+            current = open('current.conf', 'r')
+            try:
+                lines = [line.strip() for line in current]
+            finally:
+                current.close()
+            self.state = expand(parse(*lines))
+
+class RSyncException(Exception):
+    """Raised by RSyncChannel.rsync for a bad direction or a failed rsync."""
+
+class RSyncChannel(StandardChannel):
+    """StandardChannel whose job directory is mirrored with rsync.
+
+    ``setup`` pulls the remote copy before the state is loaded and ``save``
+    pushes the local copy after the state has been written.
+    """
+
+    def __init__(self, root, remote_root, experiment, state):
+        super(RSyncChannel, self).__init__(root, experiment, state)
+        # NOTE(review): remote_root is presumably either a local path or a
+        # "<host>:<path>" location of the sort used by ssh and rsync.
+        self.remote_root = remote_root
+        self.remote_path = os.path.join(self.remote_root, self.dirname)
+
+    def rsync(self, direction):
+        """Synchronize the job directory with its remote copy using rsync.
+
+        :param direction: ``'push'`` copies ``self.path`` to
+            ``self.remote_path``; ``'pull'`` copies the other way.
+        :raises RSyncException: if ``direction`` is invalid, if rsync cannot
+            be started, or if it exits with a non-zero status.
+        """
+        import subprocess
+
+        if direction == 'push':
+            source, destination = self.path, self.remote_path
+        elif direction == 'pull':
+            source, destination = self.remote_path, self.path
+        else:
+            raise RSyncException('invalid direction', direction)
+
+        # Pass the arguments as a list instead of going through a shell:
+        # the directory name is built from the state values, which may
+        # contain quotes or other shell metacharacters.
+        rsync_cmd = ['rsync', '-aq', '%s/' % source, '%s/' % destination]
+
+        # TODO: redirect error better
+        devnull = open(os.devnull, 'w')
+        try:
+            try:
+                rsync_rval = subprocess.call(rsync_cmd, stderr=devnull)
+            except OSError:
+                # rsync could not be executed at all; report it like a
+                # failed transfer so that setup() can still carry on.
+                raise RSyncException('rsync failure', (None, rsync_cmd))
+        finally:
+            devnull.close()
+        if rsync_rval != 0:
+            raise RSyncException('rsync failure', (rsync_rval, rsync_cmd))
+
+    def pull(self):
+        """Copy the remote job directory over the local one."""
+        return self.rsync('pull')
+
+    def push(self):
+        """Copy the local job directory over the remote one."""
+        return self.rsync('push')
+
+    def save(self):
+        """Save the state locally, then push the directory to the remote."""
+        super(RSyncChannel, self).save()
+        self.push()
+
+    def setup(self):
+        """Pull the remote directory if possible, then do the local setup."""
+        try:
+            self.pull()
+        except RSyncException:
+            # The experiment does not exist remotely, it's ok we will
+            # push it when we save.
+            pass
+        super(RSyncChannel, self).setup()
+
+class DBRSyncChannel(RSyncChannel):
+    """RSyncChannel meant to take its experiment from a database table.
+
+    Work in progress: ``save`` and ``setup`` only delegate to RSyncChannel;
+    the database steps are described in comments but not implemented.
+    """
+
+    def __init__(self, db, tablename, root, remote_root):
+        # NOTE(review): experiment and state are passed as None, but the
+        # parent constructor formats the state to build the directory name
+        # -- confirm that it accepts None before relying on this class.
+        super(DBRSyncChannel, self).__init__(root, remote_root, None, None)
+        self.db = db
+        self.tablename = tablename
+
+    def save(self):
+        """Save and push the state; syncing to the database is still to do."""
+        super(DBRSyncChannel, self).save()
+        # sync to db
+
+    def setup(self):
+        """Run the inherited setup; picking a job from the table is still to do."""
+        # Extract a single experiment from the table that is not already running.
+        # set self.experiment and self.state
+        super(DBRSyncChannel, self).setup()
+
+
+################################################################################
+### multiple channels
+################################################################################
+
+class MultipleChannel(Channel):
+    """Base class for channels that deal with several jobs."""
+
+    def switch(self, job, *message):
+        """Switching to ``job`` is refused unless a subclass overrides this."""
+        reason = 'This Channel does not allow switching between jobs.'
+        raise NotImplementedError(reason)
+
+    def broadcast(self, *message):
+        """Not implemented here; subclasses are expected to override it."""
+        raise NotImplementedError()
+
+class SpawnChannel(MultipleChannel):
+    """Spawns one process for each task (placeholder, nothing implemented)."""
+
+class GreenletChannel(MultipleChannel):
+    """Uses a single process for all tasks, using greenlets to switch
+    between them (placeholder, nothing implemented)."""
 
 
 
-class Run(object):
+################################################################################
+### running
+################################################################################
 
-    def __init__(self, type, arguments):
-        runner = getattr(self, 'run_%s' % type, None)
-        if not runner:
-            raise UsageError('Unknown runner: "%s"' % type)
-        self.type = type
-        self.runner = runner
-        self.arguments = arguments
-        if len(inspect.getargspec(runner)[0])-1 > len(arguments):
-            s = format_help(runner)
-            raise UsageError(s)
-        runner(*self.arguments)
+def run(type, arguments):
+    """Look up the runner registered under ``type`` and call it.
+
+    :param type: key of the runner in ``runner_registry``.
+    :param arguments: sequence of positional arguments for the runner.
+    :raises UsageError: if no runner is registered under ``type``, or if
+        the number of arguments does not fit the runner's signature (the
+        message is then the runner's formatted help).
+    """
+    runner = runner_registry.get(type)
+    if not runner:
+        raise UsageError('Unknown runner: "%s"' % type)
+    names, varargs, _, defaults = inspect.getargspec(runner)
+    maxargs = len(names)
+    # Parameters that have a default value are optional.
+    minargs = maxargs - len(defaults or ())
+    given = len(arguments)
+    too_few = given < minargs
+    # Extra arguments are fine when the runner accepts *args.
+    too_many = given > maxargs and not varargs
+    if too_few or too_many:
+        raise UsageError(format_help(runner))
+    runner(*arguments)
 
-    def run_cmdline(self, experiment, *strings):
-        """
-        Usage: cmdline <experiment> <prop1::type> <prop1=value1> <prop2=value2> ...
+runner_registry = dict()
 
-        Run an experiment with parameters provided on the command
-        line.  The symbol described by <experiment> will be imported
-        using the normal python import rules and will be called with
-        the dictionary described on the command line.
+def cmdline(experiment, *strings):
+    """
+    Start an experiment with parameters given on the command line.
+
+    Usage: cmdline <experiment> <prop1::type> <prop1=value1> <prop2=value2> ...
 
-        The signature of the function located at <experiment> must
-        look like:
-            def my_experiment(state, channel):
-                ...
+    Run an experiment with parameters provided on the command
+    line.  The symbol described by <experiment> will be imported
+    using the normal python import rules and will be called with
+    the dictionary described on the command line.
 
-        Examples of setting parameters:
-            a=2 => state['a'] = 2
-            b.c=3 => state['b']['c'] = 3
-            p::mymodule.Something => state['p']['__builder__']=mymodule.Something
+    The signature of the function located at <experiment> must
+    look like:
+        def my_experiment(channel, state):
+            ...
 
-        Example call:
-            $EXE$ cmdline mymodule.my_experiment \\
-                stopper::pylearn.stopper.nsteps \\ # use pylearn.stopper.nsteps
-                stopper.n=10000 \\ # the argument "n" of nsteps is 10000
-                lr=0.03
-        """
-        state = expand(parse(*strings))
-        self.experiment = resolve(experiment)
-        self.channel = StandardChannel(self, state)
-        self.channel.run()
-        #experiment(d, CmdlineChannel(self))
-        #print flatten(d)
+    Examples of setting parameters:
+        a=2 => state['a'] = 2
+        b.c=3 => state['b']['c'] = 3
+        p::mymodule.Something => state['p']['__builder__']=mymodule.Something
 
-    def run_help(self, topic):
-        """
-        Usage: help <topic>
+    Example call:
+        run_experiment cmdline mymodule.my_experiment \\
+            stopper::pylearn.stopper.nsteps \\ # use pylearn.stopper.nsteps
+            stopper.n=10000 \\ # the argument "n" of nsteps is 10000
+            lr=0.03
+    """
+    state = expand(parse(*strings))
+    experiment = resolve(experiment)
+    channel = RSyncChannel(os.getcwd(), os.path.realpath('yaddayadda'), experiment, state)
+    channel.run()
 
-        Get help for a topic.
-        """
-        print format_help(getattr(self, 'run_%s' % topic, None))
+runner_registry['cmdline'] = cmdline
 
 
+def help(topic = None):
+    """
+    Get help for a topic.
+
+    Usage: help <topic>
+    """
+    if topic is not None:
+        # A specific command was asked for: print its formatted help.
+        print format_help(runner_registry.get(topic))
+        return
+    # No topic given: list every registered command, in name order, next to
+    # the first line of its help text.
+    print 'Available commands: (use help <command> for more info)'
+    for name in sorted(runner_registry):
+        summary = format_help(runner_registry[name]).split('\n')[0]
+        print name.ljust(20), summary
+
+runner_registry['help'] = help
 
+################################################################################
+### main
+################################################################################
 
+def run_cmdline():
+    """Dispatch ``sys.argv`` to ``run`` and print usage errors.
+
+    The first command line argument selects the runner and the remaining
+    ones are handed to it.  A UsageError is reported on standard output
+    instead of being propagated.
+    """
+    argv = sys.argv
+    try:
+        if len(argv) < 2:
+            raise UsageError('Usage: %s <run_type> [<arguments>*]' % argv[0])
+        run(argv[1], argv[2:])
+    except UsageError, e:
+        print 'Usage error:'
+        print e
 
+if __name__ == '__main__':
+    run_cmdline()
 
 
-
-
-
-
-