Mercurial > pylearn
changeset 568:1f036d934ad9
improvements to dbdict interface
author | Olivier Breuleux <breuleuo@iro.umontreal.ca> |
---|---|
date | Wed, 03 Dec 2008 22:26:31 -0500 |
parents | d88c35e8f83a |
children | eaf4cbd20017 |
files | pylearn/dbdict/api0.py pylearn/dbdict/newstuff.py pylearn/dbdict/sql.py |
diffstat | 3 files changed, 441 insertions(+), 92 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/dbdict/api0.py Wed Dec 03 18:52:34 2008 -0500 +++ b/pylearn/dbdict/api0.py Wed Dec 03 22:26:31 2008 -0500 @@ -180,6 +180,9 @@ s.commit() s.update(d_self) + def iteritems(d_self): + return d_self.items() + def items(d_self): return [(kv.name, kv.val) for kv in d_self._attrs]
--- a/pylearn/dbdict/newstuff.py Wed Dec 03 18:52:34 2008 -0500 +++ b/pylearn/dbdict/newstuff.py Wed Dec 03 22:26:31 2008 -0500 @@ -4,8 +4,25 @@ import FileLock import portalocker from collections import defaultdict -import re, sys, inspect, os -import signal +import re, sys, inspect, os, signal, tempfile, shutil + + +import sql + + +################################################################################ +### misc +################################################################################ + +class DD(defaultdict): + def __getattr__(self, attr): + return self[attr] + def __setattr__(self, attr, value): + self[attr] = value + def __str__(self): + return 'DD%s' % dict(self) + def __repr__(self): + return str(self) ################################################################################ ### resolve @@ -57,7 +74,7 @@ def expand(d): def dd(): - return defaultdict(dd) + return DD(dd) struct = dd() for k, v in d.iteritems(): if k == '': @@ -102,10 +119,11 @@ raise UsageError('Expected a keyword argument in place of "%s"' % s1[0]) elif len(s1) == 2: k, v = s1 + v = convert(v) elif len(s2) == 2: k, v = s2 k += '.__builder__' - d[k] = convert(v) + d[k] = v return d def format_d(d, sep = '\n', space = True): @@ -120,11 +138,18 @@ help = topic.help() else: help = topic.__doc__ + if not help: + return 'No help.' + ss = map(str.rstrip, help.split('\n')) - baseline = min([len(line) - len(line.lstrip()) for line in ss if line]) + try: + baseline = min([len(line) - len(line.lstrip()) for line in ss if line]) + except: + return 'No help.' s = '\n'.join([line[baseline:] for line in ss]) s = re.sub(string = s, pattern = '\n{2,}', repl = '\n\n') s = re.sub(string = s, pattern = '(^\n*)|(\n*$)', repl = '') + return s ################################################################################ @@ -175,7 +200,8 @@ """ pass - __call__ = switch + def __call__(self, message = None): + return self.switch(message) def save_and_switch(self): self.save() @@ -197,6 +223,12 @@ +class JobError(Exception): + RUNNING = 0 + DONE = 1 + NOJOB = 2 + + class SingleChannel(Channel): def __init__(self, experiment, state): @@ -204,49 +236,60 @@ self.state = state self.feedback = None - def switch(self, message): + def switch(self, message = None): feedback = self.feedback self.feedback = None return feedback - def run(self): - # install a SIGTERM handler that asks the experiment function to return - # the next time it will call switch() - def on_sigterm(signo, frame): - self.feedback = 'stop' - signal.signal(signal.SIGTERM, on_sigterm) - signal.signal(signal.SIGINT, on_sigterm) - + def run(self, force = False): self.setup() - status = self.state['dbdict'].get('status', self.START) - if status is self.DONE: - raise Exception('The job has already completed.') - elif status is self.RUNNING: - raise Exception('The job is already running.') - self.state['dbdict'].setdefault('status', self.START) + status = self.state.dbdict.get('status', self.START) + if status is self.DONE and not force: + # If you want to disregard this, use the --force flag (not yet implemented) + raise JobError(JobError.RUNNING, + 'The job has already completed.') + elif status is self.RUNNING and not force: + raise JobError(JobError.DONE, + 'The job is already running.') + self.state.dbdict.status = self.RUNNING + v = self.INCOMPLETE with self: - v = self.experiment(self, self.state) - self.state['dbdict']['status'] = self.DONE if v is self.COMPLETE else self.START + try: + v = self.experiment(self, self.state) + finally: + self.state.dbdict.status = self.DONE if v is self.COMPLETE else self.START return v + def on_sigterm(self, signo, frame): + # SIGTERM handler. It is the experiment function's responsibility to + # call switch() often enough to get this feedback. + self.feedback = 'stop' + def __enter__(self): + # install a SIGTERM handler that asks the experiment function to return + # the next time it will call switch() + self.prev_sigterm = signal.getsignal(signal.SIGTERM) + self.prev_sigint = signal.getsignal(signal.SIGINT) + signal.signal(signal.SIGTERM, self.on_sigterm) + signal.signal(signal.SIGINT, self.on_sigterm) return self def __exit__(self, type, value, traceback): + signal.signal(signal.SIGTERM, self.prev_sigterm) + signal.signal(signal.SIGINT, self.prev_sigint) + self.prev_sigterm = None + self.prev_sigint = None self.save() class StandardChannel(SingleChannel): - def __init__(self, root, experiment, state, redirect_stdout = False, redirect_stderr = False): - self.root = root - self.experiment = experiment - self.state = state - self.dirname = format_d(self.state, sep=',', space=False) - self.path = os.path.join(self.root, self.dirname) + def __init__(self, path, experiment, state, redirect_stdout = False, redirect_stderr = False): + super(StandardChannel, self).__init__(experiment, state) + self.path = os.path.realpath(path) self.lock = None self.redirect_stdout = redirect_stdout self.redirect_stderr = redirect_stderr @@ -257,12 +300,6 @@ current.write('\n') def __enter__(self): - ###assert self.lock is None - ##lockf = os.path.join(self.path, 'lock') - ##self.lock = open(lockf, 'r+') - ##portalocker.lock(self.lock, portalocker.LOCK_EX) - #self.lock = FileLock.FileLock(os.path.join(self.path, 'lock')) - #self.lock.lock() self.old_cwd = os.getcwd() os.chdir(self.path) if self.redirect_stdout: @@ -274,7 +311,6 @@ return super(StandardChannel, self).__enter__() def __exit__(self, type, value, traceback): - ###assert self.lock is not None if self.redirect_stdout: sys.stdout.close() sys.stdout = self.old_stdout @@ -282,14 +318,11 @@ sys.stderr.close() sys.stderr = self.old_stderr os.chdir(self.old_cwd) - ##self.lock.close() - #self.lock.unlock() - ###self.lock = None return super(StandardChannel, self).__exit__(type, value, traceback) def setup(self): if not os.path.isdir(self.path): - os.mkdir(self.path) + os.makedirs(self.path) with self: origf = os.path.join(self.path, 'orig.conf') if not os.path.isfile(origf): @@ -301,38 +334,40 @@ with open(currentf, 'r') as current: self.state = expand(parse(*map(str.strip, current.readlines()))) -# origf = os.path.join(self.path, 'orig.conf') -# if not os.path.isfile(origf): -# with open(os.path.isfile(origf), 'w') as orig: -# orig.write(format_d(self.state)) -# orig.write('\n') -# currentf = os.path.join(self.path, 'current.conf') -# if os.path.isfile(currentf): -# with open(currentf, 'r') as current: -# self.state = expand(parse(*map(str.strip, current.readlines()))) - class RSyncException(Exception): pass class RSyncChannel(StandardChannel): - def __init__(self, root, remote_root, experiment, state): - super(RSyncChannel, self).__init__(root, experiment, state) - self.remote_root = remote_root - self.remote_path = os.path.join(self.remote_root, self.dirname) + def __init__(self, path, remote_path, experiment, state): + super(RSyncChannel, self).__init__(path, experiment, state) + + ssh_prefix='ssh://' + if remote_path.startswith(ssh_prefix): + remote_path = remote_path[len(ssh_prefix):] + colon_pos = remote_path.find(':') + self.host = remote_path[:colon_pos] + self.remote_path = remote_path[colon_pos+1:] + else: + self.host = '' + self.remote_path = os.path.realpath(remote_path) def rsync(self, direction): """The directory at which experiment-related files are stored. - - :returns: "<host>:<path>", of the sort used by ssh and rsync. """ - # TODO: redirect error better, use something more portable than os.system + path = self.path + + remote_path = self.remote_path + if self.host: + remote_path = ':'.join([self.host, remote_path]) + + # TODO: use something more portable than os.system if direction == 'push': - rsync_cmd = 'rsync -aq "%s/" "%s/" 2>/dev/null' % (self.path, self.remote_path) + rsync_cmd = 'rsync -ar "%s/" "%s/"' % (path, remote_path) elif direction == 'pull': - rsync_cmd = 'rsync -aq "%s/" "%s/" 2>/dev/null' % (self.remote_path, self.path) + rsync_cmd = 'rsync -ar "%s/" "%s/"' % (remote_path, path) else: raise RSyncException('invalid direction', direction) @@ -340,6 +375,18 @@ if rsync_rval != 0: raise RSyncException('rsync failure', (rsync_rval, rsync_cmd)) + def touch(self): + if self.host: + host = self.host + touch_cmd = ('ssh %(host)s "mkdir -p \'%(path)s\'"' % dict(host = self.host, + path = self.remote_path)) + else: + touch_cmd = ("mkdir -p '%(path)s'" % dict(path = self.remote_path)) + print "ECHO", touch_cmd + touch_rval = os.system(touch_cmd) + if 0 != touch_rval: + raise Exception('touch failure', (touch_rval, touch_cmd)) + def pull(self): return self.rsync('pull') @@ -351,48 +398,48 @@ self.push() def setup(self): - try: - self.pull() - except RSyncException: - # The experiment does not exist remotely, it's ok we will - # push it when we save. - pass + self.touch() + self.pull() super(RSyncChannel, self).setup() + class DBRSyncChannel(RSyncChannel): - def __init__(self, db, tablename, root, remote_root): - super(DBRsyncChannel, self).__init__(root, remote_root, None, None) - self.db = db - self.tablename = tablename + def __init__(self, username, password, hostname, dbname, tablename, path, remote_root): + self.username, self.password, self.hostname, self.dbname, self.tablename \ + = username, password, hostname, dbname, tablename + + self.db = sql.postgres_serial( + user = self.username, + password = self.password, + host = self.hostname, + database = self.dbname, + table_prefix = self.tablename) + + self.dbstate = sql.book_dct_postgres_serial(self.db) + if self.dbstate is None: + raise JobError(JobError.NOJOB, + 'No job was found to run.') + + state = expand(self.dbstate) + experiment = resolve(state.dbdict.experiment) + + remote_path = os.path.join(remote_root, self.dbname, self.tablename, str(self.dbstate.id)) + super(DBRSyncChannel, self).__init__(path, remote_path, experiment, state) def save(self): super(DBRSyncChannel, self).save() - # sync to db + self.dbstate.update(flatten(self.state)) def setup(self): # Extract a single experiment from the table that is not already running. # set self.experiment and self.state super(DBRSyncChannel, self).setup() - -################################################################################ -### multiple channels -################################################################################ - -class MultipleChannel(Channel): - def switch(self, job, *message): - raise NotImplementedError('This Channel does not allow switching between jobs.') - def broadcast(self, *message): - raise NotImplementedError() - -class SpawnChannel(MultipleChannel): - # spawns one process for each task - pass - -class GreenletChannel(MultipleChannel): - # uses a single process for all tasks, using greenlets to switch between them - pass + def run(self): + # We pass the force flag as True because the status flag is + # already set to RUNNING by book_dct in __init__ + return super(DBRSyncChannel, self).run(force = True) @@ -414,7 +461,7 @@ runner_registry = dict() -def cmdline(experiment, *strings): +def runner_cmdline(experiment, *strings): """ Start an experiment with parameters given on the command line. @@ -442,12 +489,56 @@ lr=0.03 """ state = expand(parse(*strings)) + state.dbdict.experiment = experiment experiment = resolve(experiment) - #channel = RSyncChannel(os.getcwd(), os.path.realpath('yaddayadda'), experiment, state) - channel = StandardChannel(os.getcwd(), experiment, state) + #channel = RSyncChannel('.', 'yaddayadda', experiment, state) + channel = StandardChannel(format_d(state, sep=',', space = False), + experiment, state) channel.run() -runner_registry['cmdline'] = cmdline +runner_registry['cmdline'] = runner_cmdline + + +def runner_sqlschedule(dbdescr, experiment, *strings): + try: + username, password, hostname, dbname, tablename \ + = sql.parse_dbstring(dbdescr) + except: + raise UsageError('Wrong syntax for dbdescr') + + db = sql.postgres_serial( + user = username, + password = password, + host = hostname, + database = dbname, + table_prefix = tablename) + + state = parse(*strings) + state['dbdict.experiment'] = experiment + sql.add_experiments_to_db([state], db, verbose = 1) + +runner_registry['sqlschedule'] = runner_sqlschedule + + + +def runner_sql(dbdescr, exproot): + try: + username, password, hostname, dbname, tablename \ + = sql.parse_dbstring(dbdescr) + except: + raise UsageError('Wrong syntax for dbdescr') + workdir = tempfile.mkdtemp() + print 'wdir', workdir + channel = DBRSyncChannel(username, password, hostname, dbname, tablename, + workdir, + exproot) + channel.run() + shutil.rmtree(workdir, ignore_errors=True) + +runner_registry['sql'] = runner_sql + + + def help(topic = None): @@ -458,6 +549,7 @@ """ if topic is None: print 'Available commands: (use help <command> for more info)' + print for name, command in sorted(runner_registry.iteritems()): print name.ljust(20), format_help(command).split('\n')[0] return @@ -482,3 +574,43 @@ run_cmdline() + + + + + + + + + + + + + + + + + + + + + +# fuck this shit + +# ################################################################################ +# ### multiple channels +# ################################################################################ + +# class MultipleChannel(Channel): +# def switch(self, job, *message): +# raise NotImplementedError('This Channel does not allow switching between jobs.') +# def broadcast(self, *message): +# raise NotImplementedError() + +# class SpawnChannel(MultipleChannel): +# # spawns one process for each task +# pass + +# class GreenletChannel(MultipleChannel): +# # uses a single process for all tasks, using greenlets to switch between them +# pass
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dbdict/sql.py Wed Dec 03 22:26:31 2008 -0500 @@ -0,0 +1,214 @@ + +import sys + +import sqlalchemy +from sqlalchemy import create_engine, desc +from sqlalchemy.orm import eagerload +import copy + +import psycopg2, psycopg2.extensions + +from api0 import db_from_engine, postgres_db + + +STATUS = 'dbdict.status' +PRIORITY = 'dbdict.sql.priority' +HOST = 'dbdict.sql.hostname' +HOST_WORKDIR = 'dbdict.sql.host_workdir' +PUSH_ERROR = 'dbdict.sql.push_error' + +START = 0 +RUNNING = 1 +DONE = 2 + +_TEST_CONCURRENCY = False + +def postgres_serial(user, password, host, database, **kwargs): + """Return a DbHandle instance that communicates with a postgres database at transaction + isolation_level 'SERIALIZABLE'. + + :param user: a username in the database + :param password: the password for the username + :param host: the network address of the host on which the postgres server is running + :param database: a database served by the postgres server + + """ + this = postgres_serial + + if not hasattr(this,'engine'): + def connect(): + c = psycopg2.connect(user=user, password=password, database=database, host=host) + c.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE) + return c + pool_size = 0 + this.engine = create_engine('postgres://' + ,creator=connect + ,pool_size=0 # should force the app release connections + ) + + db = db_from_engine(this.engine, **kwargs) + db._is_serialized_session_db = True + return db + +def book_dct_postgres_serial(db, retry_max_sleep=10.0, verbose=0): + """Find a trial in the lisa_db with status START. + + A trial will be returned with dbdict_status=RUNNING. + + Returns None if no such trial exists in DB. + + This function uses a serial access to the lisadb to guarantee that no other + process will retrieve the same dct. It is designed to facilitate writing + a "consumer" for a Producer-Consumer pattern based on the database. + + """ + print >> sys.stderr, """#TODO: use the priority field, not the status.""" + print >> sys.stderr, """#TODO: ignore entries with key PUSH_ERROR.""" + + s = db._session + + # NB. we need the query and attribute update to be in the same transaction + assert s.autocommit == False + + dcts_seen = set([]) + keep_trying = True + + dct = None + while (dct is None) and keep_trying: + #build a query + q = s.query(db._Dict) + q = q.options(eagerload('_attrs')) #hard-coded in api0 + q = q.filter(db._Dict._attrs.any(name=STATUS, fval=START)) + + #try to reserve a dct + try: + #first() may raise psycopg2.ProgrammingError + dct = q.first() + + if dct is not None: + assert (dct not in dcts_seen) + if verbose: print 'book_unstarted_dct retrieved, ', dct + dct._set_in_session(STATUS, RUNNING, s) + if 1: + if _TEST_CONCURRENCY: + print >> sys.stderr, 'SLEEPING BEFORE BOOKING' + time.sleep(10) + + #commit() may raise psycopg2.ProgrammingError + s.commit() + else: + print >> sys.stderr, 'DEBUG MODE: NOT RESERVING JOB!', dct + #if we get this far, the job is ours! + else: + # no jobs are left + keep_trying = False + except (psycopg2.OperationalError, + sqlalchemy.exceptions.ProgrammingError), e: + #either the first() or the commit() raised + s.rollback() # docs say to do this (or close) after commit raises exception + if verbose: print 'caught exception', e + if dct: + # first() succeeded, commit() failed + dcts_seen.add(dct) + dct = None + wait = numpy.random.rand(1)*retry_max_sleep + if verbose: print 'another process stole our dct. Waiting %f secs' % wait + time.sleep(wait) + return dct + +def book_dct(db): + print >> sys.stderr, """#TODO: use the priority field, not the status.""" + print >> sys.stderr, """#TODO: ignore entries with key self.push_error.""" + + return db.query(dbdict_status=START).first() + +def parse_dbstring(dbstring): + postgres = 'postgres://' + assert dbstring.startswith(postgres) + dbstring = dbstring[len(postgres):] + + #username_and_password + colon_pos = dbstring.find('@') + username_and_password = dbstring[:colon_pos] + dbstring = dbstring[colon_pos+1:] + + colon_pos = username_and_password.find(':') + if -1 == colon_pos: + username = username_and_password + password = None + else: + username = username_and_password[:colon_pos] + password = username_and_password[colon_pos+1:] + + #hostname + colon_pos = dbstring.find('/') + hostname = dbstring[:colon_pos] + dbstring = dbstring[colon_pos+1:] + + #dbname + colon_pos = dbstring.find('/') + dbname = dbstring[:colon_pos] + dbstring = dbstring[colon_pos+1:] + + #tablename + tablename = dbstring + + if password is None: + password = open(os.getenv('HOME')+'/.dbdict_%s'%dbname).readline()[:-1] + if False: + print 'USERNAME', username + print 'PASS', password + print 'HOST', hostname + print 'DB', dbname + print 'TABLE', tablename + + return username, password, hostname, dbname, tablename + + +def add_experiments_to_db(jobs, db, verbose=0, add_dups=False, type_check=None): + """Add experiments paramatrized by jobs[i] to database db. + + Default behaviour is to ignore jobs which are already in the database. + + If type_check is a class (instead of None) then it will be used as a type declaration for + all the elements in each job dictionary. For each key,value pair in the dictionary, there + must exist an attribute,value pair in the class meeting the following criteria: + the attribute and the key are equal, and the types of the values are equal. + + :param jobs: The parameters of experiments to run. + :type jobs: an iterable object over dictionaries + :param verbose: print which jobs are added and which are skipped + :param add_dups: False will ignore a job if it matches (on all items()) with a db entry. + :type add_dups: Bool + + :returns: list of (Bool,job[i]) in which the flags mean the corresponding job actually was + inserted. + + """ + rval = [] + for job in jobs: + job = copy.copy(job) + do_insert = add_dups or (None is db.query(**job).first()) + + if do_insert: + if type_check: + for k,v in job.items(): + if type(v) != getattr(type_check, k): + raise TypeError('Experiment contains value with wrong type', + ((k,v), getattr(type_check, k))) + + job[STATUS] = START + job[PRIORITY] = 1.0 + if verbose: + print 'ADDING ', job + db.insert(job) + rval.append((True, job)) + else: + if verbose: + print 'SKIPPING', job + rval.append((False, job)) + + + + +