changeset 568:1f036d934ad9

improvements to dbdict interface
author Olivier Breuleux <breuleuo@iro.umontreal.ca>
date Wed, 03 Dec 2008 22:26:31 -0500
parents d88c35e8f83a
children eaf4cbd20017
files pylearn/dbdict/api0.py pylearn/dbdict/newstuff.py pylearn/dbdict/sql.py
diffstat 3 files changed, 441 insertions(+), 92 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dbdict/api0.py	Wed Dec 03 18:52:34 2008 -0500
+++ b/pylearn/dbdict/api0.py	Wed Dec 03 22:26:31 2008 -0500
@@ -180,6 +180,9 @@
                 s.commit()
                 s.update(d_self)
 
+            def iteritems(d_self):
+                return d_self.items()
+
             def items(d_self):
                 return [(kv.name, kv.val) for kv in d_self._attrs]
             
--- a/pylearn/dbdict/newstuff.py	Wed Dec 03 18:52:34 2008 -0500
+++ b/pylearn/dbdict/newstuff.py	Wed Dec 03 22:26:31 2008 -0500
@@ -4,8 +4,25 @@
 import FileLock
 import portalocker
 from collections import defaultdict
-import re, sys, inspect, os
-import signal
+import re, sys, inspect, os, signal, tempfile, shutil
+
+
+import sql
+
+
+################################################################################
+### misc
+################################################################################
+
+class DD(defaultdict):
+    def __getattr__(self, attr):
+        return self[attr]
+    def __setattr__(self, attr, value):
+        self[attr] = value
+    def __str__(self):
+        return 'DD%s' % dict(self)
+    def __repr__(self):
+        return str(self)
 
 ################################################################################
 ### resolve
@@ -57,7 +74,7 @@
 
 def expand(d):
     def dd():
-        return defaultdict(dd)
+        return DD(dd)
     struct = dd()
     for k, v in d.iteritems():
         if k == '':
@@ -102,10 +119,11 @@
             raise UsageError('Expected a keyword argument in place of "%s"' % s1[0])
         elif len(s1) == 2:
             k, v = s1
+            v = convert(v)
         elif len(s2) == 2:
             k, v = s2
             k += '.__builder__'
-        d[k] = convert(v)
+        d[k] = v
     return d
 
 def format_d(d, sep = '\n', space = True):
@@ -120,11 +138,18 @@
         help = topic.help()
     else:
         help = topic.__doc__
+    if not help:
+        return 'No help.'
+
     ss = map(str.rstrip, help.split('\n'))
-    baseline = min([len(line) - len(line.lstrip()) for line in ss if line])
+    try:
+        baseline = min([len(line) - len(line.lstrip()) for line in ss if line])
+    except ValueError:
+        return 'No help.'
     s = '\n'.join([line[baseline:] for line in ss])
     s = re.sub(string = s, pattern = '\n{2,}', repl = '\n\n')
     s = re.sub(string = s, pattern = '(^\n*)|(\n*$)', repl = '')
+
     return s
 
 ################################################################################
@@ -175,7 +200,8 @@
         """
         pass
 
-    __call__ = switch
+    def __call__(self, message = None):
+        return self.switch(message)
 
     def save_and_switch(self):
         self.save()
@@ -197,6 +223,12 @@
 
 
 
+class JobError(Exception):
+    RUNNING = 0
+    DONE = 1
+    NOJOB = 2
+
+
 class SingleChannel(Channel):
 
     def __init__(self, experiment, state):
@@ -204,49 +236,60 @@
         self.state = state
         self.feedback = None
 
-    def switch(self, message):
+    def switch(self, message = None):
         feedback = self.feedback
         self.feedback = None
         return feedback
 
-    def run(self):
-        # install a SIGTERM handler that asks the experiment function to return
-        # the next time it will call switch()
-        def on_sigterm(signo, frame):
-            self.feedback = 'stop'
-        signal.signal(signal.SIGTERM, on_sigterm)
-        signal.signal(signal.SIGINT, on_sigterm)
-
+    def run(self, force = False):
         self.setup()
 
-        status = self.state['dbdict'].get('status', self.START)
-        if status is self.DONE:
-            raise Exception('The job has already completed.')
-        elif status is self.RUNNING:
-            raise Exception('The job is already running.')
-        self.state['dbdict'].setdefault('status', self.START)
+        status = self.state.dbdict.get('status', self.START)
+        if status is self.DONE and not force:
+            # If you want to disregard this, use the --force flag (not yet implemented)
+            raise JobError(JobError.DONE,
+                           'The job has already completed.')
+        elif status is self.RUNNING and not force:
+            raise JobError(JobError.RUNNING,
+                           'The job is already running.')
+        self.state.dbdict.status = self.RUNNING
 
+        v = self.INCOMPLETE
         with self:
-            v = self.experiment(self, self.state)
-            self.state['dbdict']['status'] = self.DONE if v is self.COMPLETE else self.START
+            try:
+                v = self.experiment(self, self.state)
+            finally:
+                self.state.dbdict.status = self.DONE if v is self.COMPLETE else self.START
 
         return v
 
+    def on_sigterm(self, signo, frame):
+        # SIGTERM handler. It is the experiment function's responsibility to
+        # call switch() often enough to get this feedback.
+        self.feedback = 'stop'
+
     def __enter__(self):
+        # install a SIGTERM handler that asks the experiment function to return
+        # the next time it will call switch()
+        self.prev_sigterm = signal.getsignal(signal.SIGTERM)
+        self.prev_sigint = signal.getsignal(signal.SIGINT)
+        signal.signal(signal.SIGTERM, self.on_sigterm)
+        signal.signal(signal.SIGINT, self.on_sigterm)
         return self
 
     def __exit__(self, type, value, traceback):
+        signal.signal(signal.SIGTERM, self.prev_sigterm)
+        signal.signal(signal.SIGINT, self.prev_sigint)
+        self.prev_sigterm = None
+        self.prev_sigint = None
         self.save()
 
 
 class StandardChannel(SingleChannel):
 
-    def __init__(self, root, experiment, state, redirect_stdout = False, redirect_stderr = False):
-        self.root = root
-        self.experiment = experiment
-        self.state = state
-        self.dirname = format_d(self.state, sep=',', space=False)
-        self.path = os.path.join(self.root, self.dirname)
+    def __init__(self, path, experiment, state, redirect_stdout = False, redirect_stderr = False):
+        super(StandardChannel, self).__init__(experiment, state)
+        self.path = os.path.realpath(path)
         self.lock = None
         self.redirect_stdout = redirect_stdout
         self.redirect_stderr = redirect_stderr
@@ -257,12 +300,6 @@
             current.write('\n')
 
     def __enter__(self):
-        ###assert self.lock is None
-        ##lockf = os.path.join(self.path, 'lock')
-        ##self.lock = open(lockf, 'r+')
-        ##portalocker.lock(self.lock, portalocker.LOCK_EX)
-        #self.lock = FileLock.FileLock(os.path.join(self.path, 'lock'))
-        #self.lock.lock()
         self.old_cwd = os.getcwd()
         os.chdir(self.path)
         if self.redirect_stdout:
@@ -274,7 +311,6 @@
         return super(StandardChannel, self).__enter__()
 
     def __exit__(self, type, value, traceback):
-        ###assert self.lock is not None
         if self.redirect_stdout:
             sys.stdout.close()
             sys.stdout = self.old_stdout
@@ -282,14 +318,11 @@
             sys.stderr.close()
             sys.stderr = self.old_stderr
         os.chdir(self.old_cwd)
-        ##self.lock.close()
-        #self.lock.unlock()
-        ###self.lock = None
         return super(StandardChannel, self).__exit__(type, value, traceback)
 
     def setup(self):
         if not os.path.isdir(self.path):
-            os.mkdir(self.path)
+            os.makedirs(self.path)
         with self:
             origf = os.path.join(self.path, 'orig.conf')
             if not os.path.isfile(origf):
@@ -301,38 +334,40 @@
                 with open(currentf, 'r') as current:
                     self.state = expand(parse(*map(str.strip, current.readlines())))
 
-#         origf = os.path.join(self.path, 'orig.conf')
-#         if not os.path.isfile(origf):
-#             with open(os.path.isfile(origf), 'w') as orig:
-#                 orig.write(format_d(self.state))
-#                 orig.write('\n')
-#         currentf = os.path.join(self.path, 'current.conf')
-#         if os.path.isfile(currentf):
-#             with open(currentf, 'r') as current:
-#                 self.state = expand(parse(*map(str.strip, current.readlines())))
-            
 
 class RSyncException(Exception):
     pass
 
 class RSyncChannel(StandardChannel):
 
-    def __init__(self, root, remote_root, experiment, state):
-        super(RSyncChannel, self).__init__(root, experiment, state)
-        self.remote_root = remote_root
-        self.remote_path = os.path.join(self.remote_root, self.dirname)
+    def __init__(self, path, remote_path, experiment, state):
+        super(RSyncChannel, self).__init__(path, experiment, state)
+
+        ssh_prefix='ssh://'
+        if remote_path.startswith(ssh_prefix):
+            remote_path = remote_path[len(ssh_prefix):]
+            colon_pos = remote_path.find(':')
+            self.host = remote_path[:colon_pos]
+            self.remote_path = remote_path[colon_pos+1:]
+        else:
+            self.host = ''
+            self.remote_path = os.path.realpath(remote_path)
 
     def rsync(self, direction):
         """The directory at which experiment-related files are stored.
-
-        :returns: "<host>:<path>", of the sort used by ssh and rsync.
         """
 
-        # TODO: redirect error better, use something more portable than os.system
+        path = self.path
+
+        remote_path = self.remote_path
+        if self.host:
+            remote_path = ':'.join([self.host, remote_path])
+
+        # TODO: use something more portable than os.system
         if direction == 'push':
-            rsync_cmd = 'rsync -aq "%s/" "%s/" 2>/dev/null' % (self.path, self.remote_path)
+            rsync_cmd = 'rsync -ar "%s/" "%s/"' % (path, remote_path)
         elif direction == 'pull':
-            rsync_cmd = 'rsync -aq "%s/" "%s/" 2>/dev/null' % (self.remote_path, self.path)
+            rsync_cmd = 'rsync -ar "%s/" "%s/"' % (remote_path, path)
         else:
             raise RSyncException('invalid direction', direction)
 
@@ -340,6 +375,18 @@
         if rsync_rval != 0:
             raise RSyncException('rsync failure', (rsync_rval, rsync_cmd))
 
+    def touch(self):
+        if self.host:
+            host = self.host
+            touch_cmd = ('ssh %(host)s  "mkdir -p \'%(path)s\'"' % dict(host = self.host,
+                                                                        path = self.remote_path))
+        else:
+            touch_cmd = ("mkdir -p '%(path)s'" % dict(path = self.remote_path))
+        print "ECHO", touch_cmd
+        touch_rval = os.system(touch_cmd)
+        if 0 != touch_rval:
+            raise Exception('touch failure', (touch_rval, touch_cmd))
+
     def pull(self):
         return self.rsync('pull')
 
@@ -351,48 +398,48 @@
         self.push()
 
     def setup(self):
-        try:
-            self.pull()
-        except RSyncException:
-            # The experiment does not exist remotely, it's ok we will
-            # push it when we save.
-            pass
+        self.touch()
+        self.pull()
         super(RSyncChannel, self).setup()
 
+
 class DBRSyncChannel(RSyncChannel):
 
-    def __init__(self, db, tablename, root, remote_root):
-        super(DBRsyncChannel, self).__init__(root, remote_root, None, None)
-        self.db = db
-        self.tablename = tablename
+    def __init__(self, username, password, hostname, dbname, tablename, path, remote_root):
+        self.username, self.password, self.hostname, self.dbname, self.tablename \
+            = username, password, hostname, dbname, tablename
+
+        self.db = sql.postgres_serial(
+            user = self.username, 
+            password = self.password, 
+            host = self.hostname,
+            database = self.dbname,
+            table_prefix = self.tablename)
+
+        self.dbstate = sql.book_dct_postgres_serial(self.db)
+        if self.dbstate is None:
+            raise JobError(JobError.NOJOB,
+                           'No job was found to run.')
+
+        state = expand(self.dbstate)
+        experiment = resolve(state.dbdict.experiment)
+
+        remote_path = os.path.join(remote_root, self.dbname, self.tablename, str(self.dbstate.id))
+        super(DBRSyncChannel, self).__init__(path, remote_path, experiment, state)
 
     def save(self):
         super(DBRSyncChannel, self).save()
-        # sync to db
+        self.dbstate.update(flatten(self.state))
     
     def setup(self):
         # Extract a single experiment from the table that is not already running.
         # set self.experiment and self.state
         super(DBRSyncChannel, self).setup()
 
-
-################################################################################
-### multiple channels
-################################################################################
-
-class MultipleChannel(Channel):
-    def switch(self, job, *message):
-        raise NotImplementedError('This Channel does not allow switching between jobs.')
-    def broadcast(self, *message):
-        raise NotImplementedError()
-
-class SpawnChannel(MultipleChannel):
-    # spawns one process for each task
-    pass
-
-class GreenletChannel(MultipleChannel):
-    # uses a single process for all tasks, using greenlets to switch between them
-    pass
+    def run(self):
+        # We pass the force flag as True because the status flag is
+        # already set to RUNNING by book_dct in __init__
+        return super(DBRSyncChannel, self).run(force = True)
 
 
 
@@ -414,7 +461,7 @@
 
 runner_registry = dict()
 
-def cmdline(experiment, *strings):
+def runner_cmdline(experiment, *strings):
     """
     Start an experiment with parameters given on the command line.
 
@@ -442,12 +489,56 @@
             lr=0.03
     """
     state = expand(parse(*strings))
+    state.dbdict.experiment = experiment
     experiment = resolve(experiment)
-    #channel = RSyncChannel(os.getcwd(), os.path.realpath('yaddayadda'), experiment, state)
-    channel = StandardChannel(os.getcwd(), experiment, state)
+    #channel = RSyncChannel('.', 'yaddayadda', experiment, state)
+    channel = StandardChannel(format_d(state, sep=',', space = False),
+                              experiment, state)
     channel.run()
 
-runner_registry['cmdline'] = cmdline
+runner_registry['cmdline'] = runner_cmdline
+
+
+def runner_sqlschedule(dbdescr, experiment, *strings):
+    try:
+        username, password, hostname, dbname, tablename \
+            = sql.parse_dbstring(dbdescr)
+    except:
+        raise UsageError('Wrong syntax for dbdescr')
+
+    db = sql.postgres_serial(
+        user = username, 
+        password = password, 
+        host = hostname,
+        database = dbname,
+        table_prefix = tablename)
+
+    state = parse(*strings)
+    state['dbdict.experiment'] = experiment
+    sql.add_experiments_to_db([state], db, verbose = 1)
+
+runner_registry['sqlschedule'] = runner_sqlschedule
+
+
+
+def runner_sql(dbdescr, exproot):
+    try:
+        username, password, hostname, dbname, tablename \
+            = sql.parse_dbstring(dbdescr)
+    except:
+        raise UsageError('Wrong syntax for dbdescr')
+    workdir = tempfile.mkdtemp()
+    print 'wdir', workdir
+    channel = DBRSyncChannel(username, password, hostname, dbname, tablename,
+                             workdir,
+                             exproot)
+    channel.run()
+    shutil.rmtree(workdir, ignore_errors=True)
+
+runner_registry['sql'] = runner_sql
+
+    
+
 
 
 def help(topic = None):
@@ -458,6 +549,7 @@
     """
     if topic is None:
         print 'Available commands: (use help <command> for more info)'
+        print
         for name, command in sorted(runner_registry.iteritems()):
             print name.ljust(20), format_help(command).split('\n')[0]
         return
@@ -482,3 +574,43 @@
     run_cmdline()
 
 
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# NOTE: the multiple-channel classes below are kept commented out for reference
+
+# ################################################################################
+# ### multiple channels
+# ################################################################################
+
+# class MultipleChannel(Channel):
+#     def switch(self, job, *message):
+#         raise NotImplementedError('This Channel does not allow switching between jobs.')
+#     def broadcast(self, *message):
+#         raise NotImplementedError()
+
+# class SpawnChannel(MultipleChannel):
+#     # spawns one process for each task
+#     pass
+
+# class GreenletChannel(MultipleChannel):
+#     # uses a single process for all tasks, using greenlets to switch between them
+#     pass
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/dbdict/sql.py	Wed Dec 03 22:26:31 2008 -0500
@@ -0,0 +1,214 @@
+
+import sys, os, time
+
+import sqlalchemy
+from sqlalchemy import create_engine, desc
+from sqlalchemy.orm import eagerload
+import copy, numpy
+
+import psycopg2, psycopg2.extensions 
+
+from api0 import db_from_engine, postgres_db
+
+
+STATUS = 'dbdict.status'
+PRIORITY = 'dbdict.sql.priority'
+HOST = 'dbdict.sql.hostname'
+HOST_WORKDIR = 'dbdict.sql.host_workdir'
+PUSH_ERROR = 'dbdict.sql.push_error'
+
+START = 0
+RUNNING = 1
+DONE = 2
+
+_TEST_CONCURRENCY = False
+
+def postgres_serial(user, password, host, database, **kwargs):
+    """Return a DbHandle instance that communicates with a postgres database at transaction
+    isolation_level 'SERIALIZABLE'.
+
+    :param user: a username in the database
+    :param password: the password for the username
+    :param host: the network address of the host on which the postgres server is running
+    :param database: a database served by the postgres server
+    
+    """
+    this = postgres_serial
+
+    if not hasattr(this,'engine'):
+        def connect():
+            c = psycopg2.connect(user=user, password=password, database=database, host=host)
+            c.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE)
+            return c
+        pool_size = 0
+        this.engine = create_engine('postgres://'
+                ,creator=connect
+                ,pool_size=0 # should force the app release connections
+                )
+
+    db = db_from_engine(this.engine, **kwargs)
+    db._is_serialized_session_db = True
+    return db
+
+def book_dct_postgres_serial(db, retry_max_sleep=10.0, verbose=0):
+    """Find a trial in the lisa_db with status START.
+
+    A trial will be returned with dbdict_status=RUNNING.
+
+    Returns None if no such trial exists in DB.
+
+    This function uses a serial access to the lisadb to guarantee that no other
+    process will retrieve the same dct.  It is designed to facilitate writing
+    a "consumer" for a Producer-Consumer pattern based on the database.
+
+    """
+    print >> sys.stderr, """#TODO: use the priority field, not the status."""
+    print >> sys.stderr, """#TODO: ignore entries with key PUSH_ERROR."""
+
+    s = db._session
+
+    # NB. we need the query and attribute update to be in the same transaction
+    assert s.autocommit == False 
+
+    dcts_seen = set([])
+    keep_trying = True
+
+    dct = None
+    while (dct is None) and keep_trying:
+        #build a query
+        q = s.query(db._Dict)
+        q = q.options(eagerload('_attrs')) #hard-coded in api0
+        q = q.filter(db._Dict._attrs.any(name=STATUS, fval=START))
+
+        #try to reserve a dct
+        try:
+            #first() may raise psycopg2.ProgrammingError
+            dct = q.first()
+
+            if dct is not None:
+                assert (dct not in dcts_seen)
+                if verbose: print 'book_unstarted_dct retrieved, ', dct
+                dct._set_in_session(STATUS, RUNNING, s)
+                if 1:
+                    if _TEST_CONCURRENCY:
+                        print >> sys.stderr, 'SLEEPING BEFORE BOOKING'
+                        time.sleep(10)
+
+                    #commit() may raise psycopg2.ProgrammingError
+                    s.commit()
+                else:
+                    print >> sys.stderr, 'DEBUG MODE: NOT RESERVING JOB!', dct
+                #if we get this far, the job is ours!
+            else:
+                # no jobs are left
+                keep_trying = False
+        except (psycopg2.OperationalError,
+                sqlalchemy.exceptions.ProgrammingError), e:
+            #either the first() or the commit() raised
+            s.rollback() # docs say to do this (or close) after commit raises exception
+            if verbose: print 'caught exception', e
+            if dct:
+                # first() succeeded, commit() failed
+                dcts_seen.add(dct)
+                dct = None
+            wait = numpy.random.rand(1)*retry_max_sleep
+            if verbose: print 'another process stole our dct. Waiting %f secs' % wait
+            time.sleep(wait)
+    return dct
+
+def book_dct(db):
+    print >> sys.stderr, """#TODO: use the priority field, not the status."""
+    print >> sys.stderr, """#TODO: ignore entries with key self.push_error."""
+
+    return db.query(dbdict_status=START).first()
+
+def parse_dbstring(dbstring):
+    postgres = 'postgres://'
+    assert dbstring.startswith(postgres)
+    dbstring = dbstring[len(postgres):]
+
+    #username_and_password
+    colon_pos = dbstring.find('@')
+    username_and_password = dbstring[:colon_pos]
+    dbstring = dbstring[colon_pos+1:]
+
+    colon_pos = username_and_password.find(':')
+    if -1 == colon_pos:
+        username = username_and_password
+        password = None
+    else:
+        username = username_and_password[:colon_pos]
+        password = username_and_password[colon_pos+1:]
+    
+    #hostname
+    colon_pos = dbstring.find('/')
+    hostname = dbstring[:colon_pos]
+    dbstring = dbstring[colon_pos+1:]
+
+    #dbname
+    colon_pos = dbstring.find('/')
+    dbname = dbstring[:colon_pos]
+    dbstring = dbstring[colon_pos+1:]
+
+    #tablename
+    tablename = dbstring
+
+    if password is None:
+        password = open(os.getenv('HOME')+'/.dbdict_%s'%dbname).readline()[:-1]
+    if False:
+        print 'USERNAME', username
+        print 'PASS', password
+        print 'HOST', hostname
+        print 'DB', dbname
+        print 'TABLE', tablename
+
+    return username, password, hostname, dbname, tablename
+
+
+def add_experiments_to_db(jobs, db, verbose=0, add_dups=False, type_check=None):
+    """Add experiments paramatrized by jobs[i] to database db.
+
+    Default behaviour is to ignore jobs which are already in the database.
+
+    If type_check is a class (instead of None) then it will be used as a type declaration for
+    all the elements in each job dictionary.  For each key,value pair in the dictionary, there
+    must exist an attribute,value pair in the class meeting the following criteria:
+    the attribute and the key are equal, and the types of the values are equal.
+
+    :param jobs: The parameters of experiments to run.
+    :type jobs: an iterable object over dictionaries
+    :param verbose: print which jobs are added and which are skipped
+    :param add_dups: False will ignore a job if it matches (on all items()) with a db entry.
+    :type add_dups: Bool
+
+    :returns: list of (Bool,job[i]) in which the flags mean the corresponding job actually was
+    inserted.
+
+    """
+    rval = []
+    for job in jobs:
+        job = copy.copy(job)
+        do_insert = add_dups or (None is db.query(**job).first())
+
+        if do_insert:
+            if type_check:
+                for k,v in job.items():
+                    if type(v) != getattr(type_check, k):
+                        raise TypeError('Experiment contains value with wrong type',
+                                ((k,v), getattr(type_check, k)))
+
+            job[STATUS] = START
+            job[PRIORITY] = 1.0
+            if verbose:
+                print 'ADDING  ', job
+            db.insert(job)
+            rval.append((True, job))
+        else:
+            if verbose:
+                print 'SKIPPING', job
+            rval.append((False, job))
+
+    return rval
+
+
+