Mercurial > pylearn
view pylearn/dbdict/newstuff.py @ 571:13bc6620ad95
removed references to file locking packages
author | Olivier Breuleux <breuleuo@iro.umontreal.ca> |
---|---|
date | Wed, 03 Dec 2008 22:29:40 -0500 |
parents | 1f036d934ad9 |
children | 9f5891cd4048 |
line wrap: on
line source
from __future__ import with_statement from collections import defaultdict import re, sys, inspect, os, signal, tempfile, shutil import sql ################################################################################ ### misc ################################################################################ class DD(defaultdict): def __getattr__(self, attr): return self[attr] def __setattr__(self, attr, value): self[attr] = value def __str__(self): return 'DD%s' % dict(self) def __repr__(self): return str(self) ################################################################################ ### resolve ################################################################################ def resolve(name): symbols = name.split('.') builder = __import__(symbols[0]) for sym in symbols[1:]: builder = getattr(builder, sym) return builder ################################################################################ ### dictionary ################################################################################ def convert(obj): return eval(obj) # if not isinstance(obj, str): # return obj # def kw(x): # x = x.lower() # return dict(true = True, # false = False, # none = None)[x] # for f in (kw, int, float): # try: # return f(obj) # except: # pass # return obj def flatten(obj): d = {} def helper(d, prefix, obj): if isinstance(obj, (str, int, float)): d[prefix] = obj #convert(obj) else: if isinstance(obj, dict): subd = obj else: subd = obj.state() subd['__builder__'] = '%s.%s' % (obj.__module__, obj.__class__.__name__) for k, v in subd.iteritems(): pfx = '.'.join([prefix, k]) if prefix else k helper(d, pfx, v) helper(d, '', obj) return d def expand(d): def dd(): return DD(dd) struct = dd() for k, v in d.iteritems(): if k == '': raise NotImplementedError() else: keys = k.split('.') current = struct for k2 in keys[:-1]: current = current[k2] current[keys[-1]] = v #convert(v) return struct def realize(d): if not isinstance(d, dict): return d d = dict((k, realize(v)) for k, v in d.iteritems()) if '__builder__' in d: builder = resolve(d.pop('__builder__')) return builder(**d) return d def make(d): return realize(expand(d)) ################################################################################ ### errors ################################################################################ class UsageError(Exception): pass ################################################################################ ### parsing and formatting ################################################################################ def parse(*strings): d = {} for string in strings: s1 = re.split(' *= *', string, 1) s2 = re.split(' *:: *', string, 1) if len(s1) == 1 and len(s2) == 1: raise UsageError('Expected a keyword argument in place of "%s"' % s1[0]) elif len(s1) == 2: k, v = s1 v = convert(v) elif len(s2) == 2: k, v = s2 k += '.__builder__' d[k] = v return d def format_d(d, sep = '\n', space = True): d = flatten(d) pattern = "%s = %r" if space else "%s=%r" return sep.join(pattern % (k, v) for k, v in d.iteritems()) def format_help(topic): if topic is None: return 'No help.' elif hasattr(topic, 'help'): help = topic.help() else: help = topic.__doc__ if not help: return 'No help.' ss = map(str.rstrip, help.split('\n')) try: baseline = min([len(line) - len(line.lstrip()) for line in ss if line]) except: return 'No help.' s = '\n'.join([line[baseline:] for line in ss]) s = re.sub(string = s, pattern = '\n{2,}', repl = '\n\n') s = re.sub(string = s, pattern = '(^\n*)|(\n*$)', repl = '') return s ################################################################################ ### single channels ################################################################################ # try: # import greenlet # except: # try: # from py import greenlet # except: # print >>sys.stderr, 'the greenlet module is unavailable' # greenlet = None class Channel(object): COMPLETE = None INCOMPLETE = True START = 0 """dbdict.status == START means a experiment is ready to run""" RUNNING = 1 """dbdict.status == RUNNING means a experiment is running on dbdict_hostname""" DONE = 2 """dbdict.status == DONE means a experiment has completed (not necessarily successfully)""" # Methods to be used by the experiment to communicate with the channel def save(self): """ Save the experiment's state to the various media supported by the Channel. """ raise NotImplementedError() def switch(self, message = None): """ Called from the experiment to give the control back to the channel. The following return values are meaningful: * 'stop' -> the experiment must stop as soon as possible. It may save what it needs to save. This occurs when SIGTERM or SIGINT are sent (or in user-defined circumstances). switch() may give the control to the user. In this case, the user may resume the experiment by calling switch() again. If an argument is given by the user, it will be relayed to the experiment. """ pass def __call__(self, message = None): return self.switch(message) def save_and_switch(self): self.save() self.switch() # Methods to run the experiment def setup(self): pass def __enter__(self): pass def __exit__(self): pass def run(self): pass class JobError(Exception): RUNNING = 0 DONE = 1 NOJOB = 2 class SingleChannel(Channel): def __init__(self, experiment, state): self.experiment = experiment self.state = state self.feedback = None def switch(self, message = None): feedback = self.feedback self.feedback = None return feedback def run(self, force = False): self.setup() status = self.state.dbdict.get('status', self.START) if status is self.DONE and not force: # If you want to disregard this, use the --force flag (not yet implemented) raise JobError(JobError.RUNNING, 'The job has already completed.') elif status is self.RUNNING and not force: raise JobError(JobError.DONE, 'The job is already running.') self.state.dbdict.status = self.RUNNING v = self.INCOMPLETE with self: try: v = self.experiment(self, self.state) finally: self.state.dbdict.status = self.DONE if v is self.COMPLETE else self.START return v def on_sigterm(self, signo, frame): # SIGTERM handler. It is the experiment function's responsibility to # call switch() often enough to get this feedback. self.feedback = 'stop' def __enter__(self): # install a SIGTERM handler that asks the experiment function to return # the next time it will call switch() self.prev_sigterm = signal.getsignal(signal.SIGTERM) self.prev_sigint = signal.getsignal(signal.SIGINT) signal.signal(signal.SIGTERM, self.on_sigterm) signal.signal(signal.SIGINT, self.on_sigterm) return self def __exit__(self, type, value, traceback): signal.signal(signal.SIGTERM, self.prev_sigterm) signal.signal(signal.SIGINT, self.prev_sigint) self.prev_sigterm = None self.prev_sigint = None self.save() class StandardChannel(SingleChannel): def __init__(self, path, experiment, state, redirect_stdout = False, redirect_stderr = False): super(StandardChannel, self).__init__(experiment, state) self.path = os.path.realpath(path) self.redirect_stdout = redirect_stdout self.redirect_stderr = redirect_stderr def save(self): with open(os.path.join(self.path, 'current.conf'), 'w') as current: current.write(format_d(self.state)) current.write('\n') def __enter__(self): self.old_cwd = os.getcwd() os.chdir(self.path) if self.redirect_stdout: self.old_stdout = sys.stdout sys.stdout = open('stdout', 'a') if self.redirect_stderr: self.old_stderr = sys.stderr sys.stderr = open('stderr', 'a') return super(StandardChannel, self).__enter__() def __exit__(self, type, value, traceback): if self.redirect_stdout: sys.stdout.close() sys.stdout = self.old_stdout if self.redirect_stderr: sys.stderr.close() sys.stderr = self.old_stderr os.chdir(self.old_cwd) return super(StandardChannel, self).__exit__(type, value, traceback) def setup(self): if not os.path.isdir(self.path): os.makedirs(self.path) with self: origf = os.path.join(self.path, 'orig.conf') if not os.path.isfile(origf): with open(origf, 'w') as orig: orig.write(format_d(self.state)) orig.write('\n') currentf = os.path.join(self.path, 'current.conf') if os.path.isfile(currentf): with open(currentf, 'r') as current: self.state = expand(parse(*map(str.strip, current.readlines()))) class RSyncException(Exception): pass class RSyncChannel(StandardChannel): def __init__(self, path, remote_path, experiment, state): super(RSyncChannel, self).__init__(path, experiment, state) ssh_prefix='ssh://' if remote_path.startswith(ssh_prefix): remote_path = remote_path[len(ssh_prefix):] colon_pos = remote_path.find(':') self.host = remote_path[:colon_pos] self.remote_path = remote_path[colon_pos+1:] else: self.host = '' self.remote_path = os.path.realpath(remote_path) def rsync(self, direction): """The directory at which experiment-related files are stored. """ path = self.path remote_path = self.remote_path if self.host: remote_path = ':'.join([self.host, remote_path]) # TODO: use something more portable than os.system if direction == 'push': rsync_cmd = 'rsync -ar "%s/" "%s/"' % (path, remote_path) elif direction == 'pull': rsync_cmd = 'rsync -ar "%s/" "%s/"' % (remote_path, path) else: raise RSyncException('invalid direction', direction) rsync_rval = os.system(rsync_cmd) if rsync_rval != 0: raise RSyncException('rsync failure', (rsync_rval, rsync_cmd)) def touch(self): if self.host: host = self.host touch_cmd = ('ssh %(host)s "mkdir -p \'%(path)s\'"' % dict(host = self.host, path = self.remote_path)) else: touch_cmd = ("mkdir -p '%(path)s'" % dict(path = self.remote_path)) print "ECHO", touch_cmd touch_rval = os.system(touch_cmd) if 0 != touch_rval: raise Exception('touch failure', (touch_rval, touch_cmd)) def pull(self): return self.rsync('pull') def push(self): return self.rsync('push') def save(self): super(RSyncChannel, self).save() self.push() def setup(self): self.touch() self.pull() super(RSyncChannel, self).setup() class DBRSyncChannel(RSyncChannel): def __init__(self, username, password, hostname, dbname, tablename, path, remote_root): self.username, self.password, self.hostname, self.dbname, self.tablename \ = username, password, hostname, dbname, tablename self.db = sql.postgres_serial( user = self.username, password = self.password, host = self.hostname, database = self.dbname, table_prefix = self.tablename) self.dbstate = sql.book_dct_postgres_serial(self.db) if self.dbstate is None: raise JobError(JobError.NOJOB, 'No job was found to run.') state = expand(self.dbstate) experiment = resolve(state.dbdict.experiment) remote_path = os.path.join(remote_root, self.dbname, self.tablename, str(self.dbstate.id)) super(DBRSyncChannel, self).__init__(path, remote_path, experiment, state) def save(self): super(DBRSyncChannel, self).save() self.dbstate.update(flatten(self.state)) def setup(self): # Extract a single experiment from the table that is not already running. # set self.experiment and self.state super(DBRSyncChannel, self).setup() def run(self): # We pass the force flag as True because the status flag is # already set to RUNNING by book_dct in __init__ return super(DBRSyncChannel, self).run(force = True) ################################################################################ ### running ################################################################################ def run(type, arguments): runner = runner_registry.get(type, None) if not runner: raise UsageError('Unknown runner: "%s"' % type) argspec = inspect.getargspec(runner) minargs = len(argspec[0])-(len(argspec[3]) if argspec[3] else 0) maxargs = len(argspec[0]) if minargs > len(arguments) or maxargs < len(arguments) and not argspec[1]: s = format_help(runner) raise UsageError(s) runner(*arguments) runner_registry = dict() def runner_cmdline(experiment, *strings): """ Start an experiment with parameters given on the command line. Usage: cmdline <experiment> <prop1::type> <prop1=value1> <prop2=value2> ... Run an experiment with parameters provided on the command line. The symbol described by <experiment> will be imported using the normal python import rules and will be called with the dictionary described on the command line. The signature of the function located at <experiment> must look like: def my_experiment(state, channel): ... Examples of setting parameters: a=2 => state['a'] = 2 b.c=3 => state['b']['c'] = 3 p::mymodule.Something => state['p']['__builder__']=mymodule.Something Example call: run_experiment cmdline mymodule.my_experiment \\ stopper::pylearn.stopper.nsteps \\ # use pylearn.stopper.nsteps stopper.n=10000 \\ # the argument "n" of nsteps is 10000 lr=0.03 """ state = expand(parse(*strings)) state.dbdict.experiment = experiment experiment = resolve(experiment) #channel = RSyncChannel('.', 'yaddayadda', experiment, state) channel = StandardChannel(format_d(state, sep=',', space = False), experiment, state) channel.run() runner_registry['cmdline'] = runner_cmdline def runner_sqlschedule(dbdescr, experiment, *strings): try: username, password, hostname, dbname, tablename \ = sql.parse_dbstring(dbdescr) except: raise UsageError('Wrong syntax for dbdescr') db = sql.postgres_serial( user = username, password = password, host = hostname, database = dbname, table_prefix = tablename) state = parse(*strings) state['dbdict.experiment'] = experiment sql.add_experiments_to_db([state], db, verbose = 1) runner_registry['sqlschedule'] = runner_sqlschedule def runner_sql(dbdescr, exproot): try: username, password, hostname, dbname, tablename \ = sql.parse_dbstring(dbdescr) except: raise UsageError('Wrong syntax for dbdescr') workdir = tempfile.mkdtemp() print 'wdir', workdir channel = DBRSyncChannel(username, password, hostname, dbname, tablename, workdir, exproot) channel.run() shutil.rmtree(workdir, ignore_errors=True) runner_registry['sql'] = runner_sql def help(topic = None): """ Get help for a topic. Usage: help <topic> """ if topic is None: print 'Available commands: (use help <command> for more info)' print for name, command in sorted(runner_registry.iteritems()): print name.ljust(20), format_help(command).split('\n')[0] return print format_help(runner_registry.get(topic, None)) runner_registry['help'] = help ################################################################################ ### main ################################################################################ def run_cmdline(): try: if len(sys.argv) <= 1: raise UsageError('Usage: %s <run_type> [<arguments>*]' % sys.argv[0]) run(sys.argv[1], sys.argv[2:]) except UsageError, e: print 'Usage error:' print e if __name__ == '__main__': run_cmdline() # fuck this shit # ################################################################################ # ### multiple channels # ################################################################################ # class MultipleChannel(Channel): # def switch(self, job, *message): # raise NotImplementedError('This Channel does not allow switching between jobs.') # def broadcast(self, *message): # raise NotImplementedError() # class SpawnChannel(MultipleChannel): # # spawns one process for each task # pass # class GreenletChannel(MultipleChannel): # # uses a single process for all tasks, using greenlets to switch between them # pass