# HG changeset patch # User Olivier Breuleux # Date 1228338424 18000 # Node ID a71971ccc1e4bda4814f6bf8c35e9e65e43e209b # Parent 0ac4927e9d97e5052cd2587b3ea2f633e6ad239b checkpoint diff -r 0ac4927e9d97 -r a71971ccc1e4 pylearn/dbdict/newstuff.py --- a/pylearn/dbdict/newstuff.py Sat Nov 29 16:48:49 2008 -0500 +++ b/pylearn/dbdict/newstuff.py Wed Dec 03 16:07:04 2008 -0500 @@ -1,6 +1,6 @@ from collections import defaultdict -import re +import re, sys, inspect, os ################################################################################ ### resolve @@ -18,12 +18,14 @@ ################################################################################ def convert(obj): + if not isinstance(obj, str): + return obj def kw(x): x = x.lower() return dict(true = True, false = False, none = None)[x] - for f in (int, float, kw): + for f in (kw, int, float): try: return f(obj) except: @@ -97,9 +99,14 @@ elif len(s2) == 2: k, v = s2 k += '.__builder__' - d[k] = v + d[k] = convert(v) return d +def format_d(d, sep = '\n', space = True): + d = flatten(d) + pattern = "%s = %s" if space else "%s=%s" + return sep.join(pattern % (k, v) for k, v in d.iteritems()) + def format_help(topic): if topic is None: return 'No help.' @@ -115,97 +122,267 @@ return s ################################################################################ -### running +### single channels ################################################################################ +try: + import greenlet +except: + try: + from py import greenlet + except: + print >>sys.stderr, 'the greenlet module is unavailable' + greenlet = None + +class Complete(Exception): + pass +class Incomplete(Exception): + pass + class Channel(object): - def abort(self): - raise StopIteration() - def checkpoint(self): - raise NotImplementedError() - def switch(self, job, *message): - raise NotImplementedError() - def broadcast(self, *message): + COMPLETE = None + INCOMPLETE = True + def save(self): raise NotImplementedError() - def run(self): + def switch(self, *message): + pass + def process_message(self, message): pass - -class StandardChannel(Channel): - def __init__(self, runner): - self.runner = runner - def checkpoint(self): + def setup(self): + pass + def run(self): pass + def on_sigterm(signo, frame): + channel_rval[0] = 'stop' + + #install a SIGTERM handler that asks the run_state function to return + signal.signal(signal.SIGTERM, on_sigterm) + signal.signal(signal.SIGINT, on_sigterm) +class SingleChannel(Channel): + def __init__(self, experiment, state): + self.experiment = experiment + self.state = state +# def switch(self, *message): +# if greenlet: +# if greenlet.getcurrent() is self.expg: +# self.feedback = message +# self.manager.switch(message) +# else: +# self.expg.switch(message) +# else: +# self.feedback = message + def run(self, interactive = False): + if interactive and not greenlet: + raise Exception('interactive mode requires the greenlet package to be installed (try easy_install greenlet or easy_install py)') + self.interactive = interactive + self.setup() + self.state['job'].setdefault('complete', False) + if self.state['job']['complete']: + raise Complete('The job has already completed.') +# if greenlet: +# # self.manager = greenlet.getcurrent() +# # self.expg = greenlet.greenlet(self.experiment) +# # expg.switch(self, self.state) +# else: + v = self.experiment(self, self.state) + self.state['job']['complete'] = v is COMPLETE + self.save() + return v + +class StandardChannel(SingleChannel): + def __init__(self, root, experiment, state): + self.root = root + self.experiment = experiment + self.state = state + self.dirname = format_d(self.state, sep=',', space=False) + self.path = os.path.join(self.root, self.dirname) + def save(self): + os.chdir(self.path) + current = open('current.conf', 'w') + current.write(format_d(self.state)) + current.write('\n') + def setup(self): + if not os.path.isdir(self.path): + os.mkdir(self.path) + os.chdir(self.path) + if not os.path.isfile('orig.conf'): + orig = open('orig.conf', 'w') + orig.write(format_d(self.state)) + orig.write('\n') + if os.path.isfile('current.conf'): + self.state = expand(parse(*map(str.strip, open('current.conf', 'r').readlines()))) + +class RSyncException(Exception): + pass + +class RSyncChannel(StandardChannel): + + def __init__(self, root, remote_root, experiment, state): + super(RSyncChannel, self).__init__(root, experiment, state) + self.remote_root = remote_root + self.remote_path = os.path.join(self.remote_root, self.dirname) + + def rsync(self, direction): + """The directory at which experiment-related files are stored. + + :returns: ":", of the sort used by ssh and rsync. + """ + + # TODO: redirect error better, use something more portable than os.system + if direction == 'push': + rsync_cmd = 'rsync -aq "%s/" "%s/" 2>/dev/null' % (self.path, self.remote_path) + elif direction == 'pull': + rsync_cmd = 'rsync -aq "%s/" "%s/" 2>/dev/null' % (self.remote_path, self.path) + else: + raise RSyncException('invalid direction', direction) + + rsync_rval = os.system(rsync_cmd) + if rsync_rval != 0: + raise RSyncException('rsync failure', (rsync_rval, rsync_cmd)) + + def pull(self): + return self.rsync('pull') + + def push(self): + return self.rsync('push') + + def save(self): + super(RSyncChannel, self).save() + self.push() + + def setup(self): + try: + self.pull() + except RSyncException: + # The experiment does not exist remotely, it's ok we will + # push it when we save. + pass + super(RSyncChannel, self).setup() + +class DBRSyncChannel(RSyncChannel): + + def __init__(self, db, tablename, root, remote_root): + super(DBRsyncChannel, self).__init__(root, remote_root, None, None) + self.db = db + self.tablename = tablename + + def save(self): + super(DBRSyncChannel, self).save() + # sync to db + + def setup(self): + # Extract a single experiment from the table that is not already running. + # set self.experiment and self.state + super(DBRSyncChannel, self).setup() + + +################################################################################ +### multiple channels +################################################################################ + +class MultipleChannel(Channel): + def switch(self, job, *message): + raise NotImplementedError('This Channel does not allow switching between jobs.') + def broadcast(self, *message): + raise NotImplementedError() + +class SpawnChannel(MultipleChannel): + # spawns one process for each task + pass + +class GreenletChannel(MultipleChannel): + # uses a single process for all tasks, using greenlets to switch between them + pass -class Run(object): +################################################################################ +### running +################################################################################ - def __init__(self, type, arguments): - runner = getattr(self, 'run_%s' % type, None) - if not runner: - raise UsageError('Unknown runner: "%s"' % type) - self.type = type - self.runner = runner - self.arguments = arguments - if len(inspect.getargspec(runner)[0])-1 > len(arguments): - s = format_help(runner) - raise UsageError(s) - runner(*self.arguments) +def run(type, arguments): + runner = runner_registry.get(type, None) + if not runner: + raise UsageError('Unknown runner: "%s"' % type) + argspec = inspect.getargspec(runner) + minargs = len(argspec[0])-(len(argspec[3]) if argspec[3] else 0) + maxargs = len(argspec[0]) + if minargs > len(arguments) or maxargs < len(arguments) and not argspec[1]: + s = format_help(runner) + raise UsageError(s) + runner(*arguments) - def run_cmdline(self, experiment, *strings): - """ - Usage: cmdline ... +runner_registry = dict() - Run an experiment with parameters provided on the command - line. The symbol described by will be imported - using the normal python import rules and will be called with - the dictionary described on the command line. +def cmdline(experiment, *strings): + """ + Start an experiment with parameters given on the command line. + + Usage: cmdline ... - The signature of the function located at must - look like: - def my_experiment(state, channel): - ... + Run an experiment with parameters provided on the command + line. The symbol described by will be imported + using the normal python import rules and will be called with + the dictionary described on the command line. - Examples of setting parameters: - a=2 => state['a'] = 2 - b.c=3 => state['b']['c'] = 3 - p::mymodule.Something => state['p']['__builder__']=mymodule.Something + The signature of the function located at must + look like: + def my_experiment(state, channel): + ... - Example call: - $EXE$ cmdline mymodule.my_experiment \\ - stopper::pylearn.stopper.nsteps \\ # use pylearn.stopper.nsteps - stopper.n=10000 \\ # the argument "n" of nsteps is 10000 - lr=0.03 - """ - state = expand(parse(*strings)) - self.experiment = resolve(experiment) - self.channel = StandardChannel(self, state) - self.channel.run() - #experiment(d, CmdlineChannel(self)) - #print flatten(d) + Examples of setting parameters: + a=2 => state['a'] = 2 + b.c=3 => state['b']['c'] = 3 + p::mymodule.Something => state['p']['__builder__']=mymodule.Something - def run_help(self, topic): - """ - Usage: help + Example call: + run_experiment cmdline mymodule.my_experiment \\ + stopper::pylearn.stopper.nsteps \\ # use pylearn.stopper.nsteps + stopper.n=10000 \\ # the argument "n" of nsteps is 10000 + lr=0.03 + """ + state = expand(parse(*strings)) + experiment = resolve(experiment) + channel = RSyncChannel(os.getcwd(), os.path.realpath('yaddayadda'), experiment, state) + channel.run() - Get help for a topic. - """ - print format_help(getattr(self, 'run_%s' % topic, None)) +runner_registry['cmdline'] = cmdline +def help(topic = None): + """ + Get help for a topic. + Usage: help + """ + if topic is None: + print 'Available commands: (use help for more info)' + for name, command in sorted(runner_registry.iteritems()): + print name.ljust(20), format_help(command).split('\n')[0] + return + print format_help(runner_registry.get(topic, None)) + +runner_registry['help'] = help +################################################################################ +### main +################################################################################ +def run_cmdline(): + try: + if len(sys.argv) <= 1: + raise UsageError('Usage: %s [*]' % sys.argv[0]) + run(sys.argv[1], sys.argv[2:]) + except UsageError, e: + print 'Usage error:' + print e +if __name__ == '__main__': + run_cmdline() - - - - -