view pylearn/dbdict/newstuff.py @ 571:13bc6620ad95

removed references to file locking packages
author Olivier Breuleux <breuleuo@iro.umontreal.ca>
date Wed, 03 Dec 2008 22:29:40 -0500
parents 1f036d934ad9
children 9f5891cd4048
line wrap: on
line source


from __future__ import with_statement

from collections import defaultdict
import re, sys, inspect, os, signal, tempfile, shutil


import sql


################################################################################
### misc
################################################################################

class DD(defaultdict):
    """A ``defaultdict`` whose items are also reachable as attributes.

    ``d.key`` reads ``d['key']`` and ``d.key = value`` performs
    ``d['key'] = value``.  Reading a missing key follows the normal
    ``defaultdict`` rules: the default factory is called if there is one,
    otherwise ``KeyError`` is raised.
    """

    def __getattr__(self, attr):
        # __getattr__ only runs when regular attribute lookup has failed.
        # Protocol probes such as copy.deepcopy's
        # getattr(x, '__deepcopy__', None) must get an AttributeError here;
        # answering them with a freshly created default entry both pollutes
        # the dictionary and hands the caller a bogus "special method".
        # Double-underscore keys that were stored explicitly (for example
        # '__builder__') remain readable as attributes.
        if attr[:2] == '__' and attr[-2:] == '__' and attr not in self:
            raise AttributeError(attr)
        return self[attr]

    def __setattr__(self, attr, value):
        # Every attribute assignment is stored as a dictionary item.
        self[attr] = value

    def __str__(self):
        return 'DD%s' % dict(self)

    def __repr__(self):
        return str(self)

################################################################################
### resolve
################################################################################

def resolve(name):
    """Return the object designated by the dotted path ``name``.

    The first component is imported as a top-level module and each following
    component is looked up as an attribute of the previous object.  When an
    attribute is missing, the path so far is imported as a submodule before
    giving up, so that e.g. ``'package.module.function'`` works even if
    ``package`` does not import ``module`` by itself.

    Raises ``ImportError`` if the top-level module cannot be imported and
    ``AttributeError`` if a later component cannot be found.
    """
    symbols = name.split('.')
    path = symbols[0]
    builder = __import__(path)
    for sym in symbols[1:]:
        path = '%s.%s' % (path, sym)
        try:
            builder = getattr(builder, sym)
        except AttributeError:
            # Possibly a submodule that nobody has imported yet; importing
            # it binds it as an attribute of its parent package.
            try:
                __import__(path)
            except ImportError:
                raise AttributeError('cannot resolve "%s" (in "%s")' % (path, name))
            builder = getattr(builder, sym)
    return builder

################################################################################
### dictionary
################################################################################

def convert(obj):
    """Evaluate ``obj`` as a Python expression and return the result.

    parse() uses this to turn the right-hand side of ``key=value`` strings
    into Python values.

    WARNING(review): ``eval`` executes arbitrary code.  The strings reaching
    this function come from the command line and from ``current.conf``
    files, so only use it on input you trust.  A value that is not a valid
    expression (for example a bare, unquoted word) raises from ``eval``.
    """
    return eval(obj)
#     if not isinstance(obj, str):
#         return obj
#     def kw(x):
#         x = x.lower()
#         return dict(true = True,
#                     false = False,
#                     none = None)[x]
#     for f in (kw, int, float):
#         try:
#             return f(obj)
#         except:
#             pass
#     return obj

def flatten(obj):
    """Flatten a nested structure into a dict with dotted keys.

    * ``None``, strings, ints and floats are leaves and are stored as is.
    * Dictionaries are traversed; the keys met on the way down are joined
      with ``'.'``.
    * Any other object must provide a ``state()`` method returning a
      dictionary.  That dictionary is traversed like a normal one, with an
      extra ``'__builder__'`` entry holding ``'<module>.<class name>'`` of
      the object.

    A leaf given at the top level is stored under the empty key ``''``.
    """
    flat = {}

    def helper(prefix, node):
        # None is a leaf too: format_d() writes it with %r and parse() reads
        # it back, so it has to survive a later flatten().
        if node is None or isinstance(node, (str, int, float)):
            flat[prefix] = node
            return
        if isinstance(node, dict):
            subd = node
        else:
            # Work on a copy so that the '__builder__' marker does not leak
            # into the dictionary owned by the object.
            subd = dict(node.state())
            subd['__builder__'] = '%s.%s' % (node.__module__, node.__class__.__name__)
        # items() works for plain dicts on every Python version.
        for key, value in subd.items():
            if prefix:
                path = '.'.join([prefix, key])
            else:
                path = key
            helper(path, value)

    helper('', obj)
    return flat

def expand(d):
    """Turn a flat dict with dotted keys into nested ``DD`` dictionaries.

    ``{'a.b': 1}`` becomes a structure where ``result.a.b == 1``.  The
    intermediate levels are created on demand.  An empty key is not
    supported and raises ``NotImplementedError``.
    """
    def nested():
        # A DD whose missing entries are themselves auto-nesting DDs.
        return DD(nested)

    root = nested()
    for dotted, value in d.iteritems():
        if dotted == '':
            raise NotImplementedError()
        parts = dotted.split('.')
        node = root
        # Walk (and implicitly create) every level but the last one.
        for part in parts[:-1]:
            node = node[part]
        node[parts[-1]] = value
    return root

def realize(d):
    """Recursively build the objects described by a nested dictionary.

    Non-dictionaries are returned unchanged.  Dictionaries are rebuilt as
    plain dicts with realized values; if one of them carries a
    ``'__builder__'`` entry, that entry is resolved to a callable which is
    called with the remaining entries as keyword arguments, and its result
    is returned instead of the dictionary.
    """
    if isinstance(d, dict):
        built = {}
        for key, value in d.iteritems():
            built[key] = realize(value)
        if '__builder__' not in built:
            return built
        factory = resolve(built.pop('__builder__'))
        return factory(**built)
    return d

def make(d):
    """Expand the flat, dotted-key dict ``d`` and realize the result."""
    nested = expand(d)
    return realize(nested)

################################################################################
### errors
################################################################################

class UsageError(Exception):
    """Raised when the arguments given by the user cannot be interpreted."""

################################################################################
### parsing and formatting
################################################################################

def parse(*strings):
    """Parse ``key=value`` and ``key::builder`` strings into a flat dict.

    * ``key=value`` stores ``convert(value)`` under ``key``; spaces around
      the first ``=`` are ignored.
    * ``key::builder`` stores the string ``builder`` under
      ``key.__builder__``; spaces around the first ``::`` are ignored.

    A string containing ``=`` is always treated as the first form.  A string
    matching neither form raises ``UsageError``.
    """
    parsed = {}
    for item in strings:
        by_equals = re.split(' *= *', item, 1)
        if len(by_equals) == 2:
            key, value = by_equals
            parsed[key] = convert(value)
            continue
        by_colons = re.split(' *:: *', item, 1)
        if len(by_colons) != 2:
            raise UsageError('Expected a keyword argument in place of "%s"' % by_equals[0])
        key, value = by_colons
        parsed[key + '.__builder__'] = value
    return parsed

def format_d(d, sep = '\n', space = True):
    """Render ``d`` as ``key = repr(value)`` entries joined by ``sep``.

    ``d`` is flattened first, so nested structures appear with dotted keys.
    With ``space`` false the entries are written as ``key=repr(value)``.
    The entries follow the iteration order of the flattened dictionary.
    """
    if space:
        template = "%s = %r"
    else:
        template = "%s=%r"
    entries = []
    for key, value in flatten(d).iteritems():
        entries.append(template % (key, value))
    return sep.join(entries)

def format_help(topic):
    """Return a cleaned-up help text for ``topic``.

    The text is ``topic.help()`` when the topic has a ``help`` attribute and
    ``topic.__doc__`` otherwise.  The common leading indentation is removed,
    trailing whitespace is stripped from every line, runs of blank lines are
    collapsed into a single blank line, and leading/trailing newlines are
    dropped.

    Returns ``'No help.'`` when ``topic`` is ``None`` or when no non-blank
    text is available.
    """
    if topic is None:
        return 'No help.'
    if hasattr(topic, 'help'):
        text = topic.help()
    else:
        text = topic.__doc__
    if not text:
        return 'No help.'

    # A real list (not a lazy map) because it is traversed twice below; the
    # method call also works for unicode text, unlike str.rstrip.
    lines = [line.rstrip() for line in text.split('\n')]
    indents = [len(line) - len(line.lstrip()) for line in lines if line]
    if not indents:
        # Only blank lines: nothing to show.
        return 'No help.'
    baseline = min(indents)

    body = '\n'.join([line[baseline:] for line in lines])
    body = re.sub('\n{2,}', '\n\n', body)
    return body.strip('\n')

################################################################################
### single channels
################################################################################

# try:
#     import greenlet
# except:
#     try:
#         from py import greenlet
#     except:
#         print >>sys.stderr, 'the greenlet module is unavailable'
#         greenlet = None


class Channel(object):
    """Base interface between an experiment and the code running it.

    The experiment uses save() and switch() (or calls the channel directly)
    to communicate; the runner uses setup(), run() and the context-manager
    protocol.  Except for save(), the methods here do nothing and are meant
    to be overridden.
    """

    # Values an experiment function can return to describe its outcome.
    COMPLETE = None
    INCOMPLETE = True

    # dbdict.status == START means a experiment is ready to run
    START = 0
    # dbdict.status == RUNNING means a experiment is running on dbdict_hostname
    RUNNING = 1
    # dbdict.status == DONE means a experiment has completed (not necessarily successfully)
    DONE = 2

    # Methods to be used by the experiment to communicate with the channel

    def save(self):
        """
        Save the experiment's state to the various media supported by
        the Channel.
        """
        raise NotImplementedError()

    def switch(self, message = None):
        """
        Called from the experiment to give the control back to the channel.
        The following return values are meaningful:
          * 'stop' -> the experiment must stop as soon as possible. It may save what
            it needs to save. This occurs when SIGTERM or SIGINT are sent (or in
            user-defined circumstances).
        switch() may give the control to the user. In this case, the user may
        resume the experiment by calling switch() again. If an argument is given
        by the user, it will be relayed to the experiment.
        """
        pass

    def __call__(self, message = None):
        """Shorthand for switch(message)."""
        return self.switch(message)

    def save_and_switch(self):
        """Save the state, then switch; return what switch() returned.

        The feedback is passed through so that the caller can see a 'stop'
        request.
        """
        self.save()
        return self.switch()

    # Methods to run the experiment

    def setup(self):
        """Prepare the channel before the experiment runs (no-op here)."""
        pass

    def __enter__(self):
        # Return the channel so that ``with channel as c:`` binds it.
        return self

    def __exit__(self, type = None, value = None, traceback = None):
        # Accept the three arguments passed by the ``with`` statement.
        # Nothing is returned, so exceptions are never swallowed.
        pass

    def run(self):
        """Run the experiment (no-op here)."""
        pass



class JobError(Exception):
    """Raised when a job cannot be run.

    Instances are created with one of the codes below followed by a
    human-readable message.
    """
    RUNNING, DONE, NOJOB = 0, 1, 2


class SingleChannel(Channel):
    """Channel that runs a single experiment function in this process.

    ``experiment`` is called as ``experiment(channel, state)``.  ``state``
    must expose a ``dbdict`` entry in which the job status is kept.
    """

    def __init__(self, experiment, state):
        self.experiment = experiment
        self.state = state
        # Pending message for the experiment; delivered by switch().
        self.feedback = None

    def switch(self, message = None):
        """Return the pending feedback (e.g. 'stop') and clear it."""
        feedback, self.feedback = self.feedback, None
        return feedback

    def run(self, force = False):
        """Set up the channel, run the experiment and return its result.

        Unless ``force`` is true, a ``JobError`` is raised when the stored
        status says that the job is already done (code ``JobError.DONE``) or
        currently running (code ``JobError.RUNNING``).

        While the experiment runs the status is RUNNING.  Afterwards it
        becomes DONE if the experiment returned COMPLETE, and START in every
        other case, including when the experiment raised.
        """
        self.setup()

        status = self.state.dbdict.get('status', self.START)
        if not force:
            # Compare by value: the status may come back from storage as an
            # equal but distinct integer object.
            if status == self.DONE:
                # If you want to disregard this, use the --force flag (not yet implemented)
                raise JobError(JobError.DONE,
                               'The job has already completed.')
            if status == self.RUNNING:
                raise JobError(JobError.RUNNING,
                               'The job is already running.')
        self.state.dbdict.status = self.RUNNING

        v = self.INCOMPLETE
        with self:
            try:
                v = self.experiment(self, self.state)
            finally:
                # Runs before __exit__, so the saved state carries the final
                # status.
                self.state.dbdict.status = self.DONE if v is self.COMPLETE else self.START

        return v

    def on_sigterm(self, signo, frame):
        # SIGTERM handler. It is the experiment function's responsibility to
        # call switch() often enough to get this feedback.
        self.feedback = 'stop'

    def __enter__(self):
        # install a SIGTERM handler that asks the experiment function to return
        # the next time it will call switch()
        self.prev_sigterm = signal.getsignal(signal.SIGTERM)
        self.prev_sigint = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGTERM, self.on_sigterm)
        signal.signal(signal.SIGINT, self.on_sigterm)
        return self

    def __exit__(self, type, value, traceback):
        # Put the previous handlers back, then persist the state.
        signal.signal(signal.SIGTERM, self.prev_sigterm)
        signal.signal(signal.SIGINT, self.prev_sigint)
        self.prev_sigterm = None
        self.prev_sigint = None
        self.save()


class StandardChannel(SingleChannel):
    """SingleChannel that keeps its files in a working directory.

    The state is saved to ``<path>/current.conf``; the initial state is
    written once to ``<path>/orig.conf``.  While the channel is entered the
    process works inside ``path`` and, if requested, ``sys.stdout`` and
    ``sys.stderr`` are appended to the files ``stdout`` and ``stderr``
    there.
    """

    def __init__(self, path, experiment, state, redirect_stdout = False, redirect_stderr = False):
        super(StandardChannel, self).__init__(experiment, state)
        self.path = os.path.realpath(path)
        self.redirect_stdout = redirect_stdout
        self.redirect_stderr = redirect_stderr

    def save(self):
        """Write the current state to ``current.conf`` in ``self.path``."""
        with open(os.path.join(self.path, 'current.conf'), 'w') as current:
            current.write(format_d(self.state))
            current.write('\n')

    def __enter__(self):
        self.old_cwd = os.getcwd()
        os.chdir(self.path)
        # The redirection files are opened relative to the new directory.
        if self.redirect_stdout:
            self.old_stdout = sys.stdout
            sys.stdout = open('stdout', 'a')
        if self.redirect_stderr:
            self.old_stderr = sys.stderr
            sys.stderr = open('stderr', 'a')
        return super(StandardChannel, self).__enter__()

    def __exit__(self, type, value, traceback):
        if self.redirect_stdout:
            sys.stdout.close()
            sys.stdout = self.old_stdout
        if self.redirect_stderr:
            sys.stderr.close()
            sys.stderr = self.old_stderr
        os.chdir(self.old_cwd)
        return super(StandardChannel, self).__exit__(type, value, traceback)

    def setup(self):
        """Create the working directory and load any previously saved state.

        ``orig.conf`` is written only if it does not exist yet.  If
        ``current.conf`` exists, the state is replaced by its contents.
        """
        if not os.path.isdir(self.path):
            os.makedirs(self.path)
        with self:
            origf = os.path.join(self.path, 'orig.conf')
            if not os.path.isfile(origf):
                with open(origf, 'w') as orig:
                    orig.write(format_d(self.state))
                    orig.write('\n')
            currentf = os.path.join(self.path, 'current.conf')
            if os.path.isfile(currentf):
                with open(currentf, 'r') as current:
                    lines = [line.strip() for line in current.readlines()]
                # Skip blank lines (an empty state is saved as a lone
                # newline); parse() would reject them with a UsageError.
                self.state = expand(parse(*[line for line in lines if line]))


class RSyncException(Exception):
    """Raised when a synchronisation request is invalid or rsync fails."""

class RSyncChannel(StandardChannel):
    """StandardChannel whose working directory is mirrored with rsync.

    ``remote_path`` is either a local path or ``ssh://<host>:<path>``.  The
    remote directory is pulled into ``path`` by setup() and ``path`` is
    pushed back after every save().
    """

    def __init__(self, path, remote_path, experiment, state):
        super(RSyncChannel, self).__init__(path, experiment, state)

        ssh_prefix='ssh://'
        if remote_path.startswith(ssh_prefix):
            remote_path = remote_path[len(ssh_prefix):]
            colon_pos = remote_path.find(':')
            if colon_pos < 0:
                # Without this check find() returns -1 and host and path
                # would be silently cut at the wrong place.
                raise RSyncException('expected ssh://<host>:<path>', remote_path)
            self.host = remote_path[:colon_pos]
            self.remote_path = remote_path[colon_pos+1:]
        else:
            self.host = ''
            self.remote_path = os.path.realpath(remote_path)

    def rsync(self, direction):
        """Synchronise the local and remote directories with rsync.

        ``direction`` is ``'push'`` (local to remote) or ``'pull'`` (remote
        to local).  Raises ``RSyncException`` for any other direction and
        when the rsync command returns a non-zero status.
        """

        path = self.path

        remote_path = self.remote_path
        if self.host:
            remote_path = ':'.join([self.host, remote_path])

        # TODO: use something more portable than os.system
        if direction == 'push':
            rsync_cmd = 'rsync -ar "%s/" "%s/"' % (path, remote_path)
        elif direction == 'pull':
            rsync_cmd = 'rsync -ar "%s/" "%s/"' % (remote_path, path)
        else:
            raise RSyncException('invalid direction', direction)

        rsync_rval = os.system(rsync_cmd)
        if rsync_rval != 0:
            raise RSyncException('rsync failure', (rsync_rval, rsync_cmd))

    def touch(self):
        """Create the remote directory (``mkdir -p``), through ssh if needed.

        Raises ``RSyncException`` when the command returns a non-zero status.
        """
        if self.host:
            touch_cmd = ('ssh %(host)s  "mkdir -p \'%(path)s\'"' % dict(host = self.host,
                                                                        path = self.remote_path))
        else:
            touch_cmd = ("mkdir -p '%(path)s'" % dict(path = self.remote_path))
        # Single-argument form: prints the same text under the print
        # statement and the print function.
        print("ECHO %s" % touch_cmd)
        touch_rval = os.system(touch_cmd)
        if 0 != touch_rval:
            raise RSyncException('touch failure', (touch_rval, touch_cmd))

    def pull(self):
        """Copy the remote directory into the local one."""
        return self.rsync('pull')

    def push(self):
        """Copy the local directory to the remote one."""
        return self.rsync('push')

    def save(self):
        """Save the state locally, then push the directory."""
        super(RSyncChannel, self).save()
        self.push()

    def setup(self):
        """Create the remote directory, pull it, then do the local setup."""
        self.touch()
        self.pull()
        super(RSyncChannel, self).setup()


class DBRSyncChannel(RSyncChannel):
    """RSyncChannel whose job is taken from, and reported to, a database.

    The constructor books a job through the ``sql`` module; its state and
    experiment come from the booked entry, and the remote directory is
    ``<remote_root>/<dbname>/<tablename>/<job id>``.
    """

    def __init__(self, username, password, hostname, dbname, tablename, path, remote_root):
        self.username = username
        self.password = password
        self.hostname = hostname
        self.dbname = dbname
        self.tablename = tablename

        self.db = sql.postgres_serial(user = username,
                                      password = password,
                                      host = hostname,
                                      database = dbname,
                                      table_prefix = tablename)

        booked = sql.book_dct_postgres_serial(self.db)
        self.dbstate = booked
        if booked is None:
            raise JobError(JobError.NOJOB,
                           'No job was found to run.')

        state = expand(booked)
        experiment = resolve(state.dbdict.experiment)

        job_dir = str(booked.id)
        remote_path = os.path.join(remote_root, dbname, tablename, job_dir)
        super(DBRSyncChannel, self).__init__(path, remote_path, experiment, state)

    def save(self):
        """Save and push as usual, then update the database entry."""
        super(DBRSyncChannel, self).save()
        flat_state = flatten(self.state)
        self.dbstate.update(flat_state)

    def setup(self):
        # The job was already booked by __init__; only the inherited setup
        # remains to be done.
        super(DBRSyncChannel, self).setup()

    def run(self):
        # force is True because, according to the original note, the status
        # flag is already set to RUNNING by book_dct in __init__.
        return super(DBRSyncChannel, self).run(force = True)



################################################################################
### running
################################################################################

def run(type, arguments):
    """Look up the runner registered as ``type`` and call it with ``arguments``.

    Raises ``UsageError`` when no such runner exists, or, with the runner's
    help text as message, when the number of arguments does not fit the
    runner's signature.
    """
    runner = runner_registry.get(type, None)
    if not runner:
        raise UsageError('Unknown runner: "%s"' % type)

    names, varargs, varkw, defaults = inspect.getargspec(runner)
    maxargs = len(names)
    if defaults:
        minargs = maxargs - len(defaults)
    else:
        minargs = maxargs

    given = len(arguments)
    too_few = given < minargs
    # Extra arguments are fine when the runner accepts *args.
    too_many = given > maxargs and not varargs
    if too_few or too_many:
        raise UsageError(format_help(runner))
    runner(*arguments)

# Maps a run type given on the command line to the function implementing it.
runner_registry = {}

def runner_cmdline(experiment, *strings):
    """
    Start an experiment with parameters given on the command line.

    Usage: cmdline <experiment> <prop1::type> <prop1=value1> <prop2=value2> ...

    Run an experiment with parameters provided on the command
    line.  The symbol described by <experiment> will be imported
    using the normal python import rules and will be called with
    the dictionary described on the command line.

    The signature of the function located at <experiment> must
    look like:
        def my_experiment(channel, state):
            ...

    Examples of setting parameters:
        a=2 => state['a'] = 2
        b.c=3 => state['b']['c'] = 3
        p::mymodule.Something => state['p']['__builder__']='mymodule.Something'

    Example call:
        run_experiment cmdline mymodule.my_experiment \\
            stopper::pylearn.stopper.nsteps \\ # use pylearn.stopper.nsteps
            stopper.n=10000 \\ # the argument "n" of nsteps is 10000
            lr=0.03
    """
    # NOTE: the signature shown in the help text follows what the channel
    # actually does, i.e. experiment(channel, state).
    state = expand(parse(*strings))
    state.dbdict.experiment = experiment
    experiment = resolve(experiment)
    #channel = RSyncChannel('.', 'yaddayadda', experiment, state)
    # The working directory is named after the comma-separated parameters.
    channel = StandardChannel(format_d(state, sep=',', space = False),
                              experiment, state)
    channel.run()

runner_registry['cmdline'] = runner_cmdline


def runner_sqlschedule(dbdescr, experiment, *strings):
    """
    Schedule an experiment in a database.

    Usage: sqlschedule <dbdescr> <experiment> <prop1::type> <prop1=value1> ...

    The parameters are parsed as for the cmdline command.  The
    experiment is not run; it is added to the table described by
    <dbdescr>.
    """
    try:
        username, password, hostname, dbname, tablename \
            = sql.parse_dbstring(dbdescr)
    except Exception:
        # Not a bare except: KeyboardInterrupt and SystemExit must not be
        # reported as a syntax problem.
        raise UsageError('Wrong syntax for dbdescr')

    db = sql.postgres_serial(
        user = username, 
        password = password, 
        host = hostname,
        database = dbname,
        table_prefix = tablename)

    # The state stays flat (dotted keys) when it is handed to the database.
    state = parse(*strings)
    state['dbdict.experiment'] = experiment
    sql.add_experiments_to_db([state], db, verbose = 1)

runner_registry['sqlschedule'] = runner_sqlschedule



def runner_sql(dbdescr, exproot):
    """
    Run a job taken from a database.

    Usage: sql <dbdescr> <exproot>

    A job is booked in the table described by <dbdescr> and run in
    a temporary directory which is synchronised with a directory
    below <exproot>.
    """
    try:
        username, password, hostname, dbname, tablename \
            = sql.parse_dbstring(dbdescr)
    except Exception:
        # Not a bare except: KeyboardInterrupt and SystemExit must not be
        # reported as a syntax problem.
        raise UsageError('Wrong syntax for dbdescr')
    workdir = tempfile.mkdtemp()
    # Single-argument form: prints the same text under the print statement
    # and the print function.
    print('wdir %s' % workdir)
    try:
        channel = DBRSyncChannel(username, password, hostname, dbname, tablename,
                                 workdir,
                                 exproot)
        channel.run()
    finally:
        # Remove the temporary directory even when no job could be booked
        # or the job raised.
        shutil.rmtree(workdir, ignore_errors=True)

runner_registry['sql'] = runner_sql

    



def help(topic = None):
    """
    Get help for a topic.

    Usage: help <topic>
    """
    if topic is not None:
        # Unknown topics resolve to None, for which format_help has a
        # fallback text.
        print(format_help(runner_registry.get(topic, None)))
        return

    print('Available commands: (use help <command> for more info)')
    print('')
    # One line per command: its name, padded, and the first line of its help.
    for name in sorted(runner_registry):
        summary = format_help(runner_registry[name]).split('\n')[0]
        print('%s %s' % (name.ljust(20), summary))

runner_registry['help'] = help

################################################################################
### main
################################################################################

def run_cmdline():
    """Entry point: dispatch ``sys.argv`` to run() and report usage errors.

    A ``UsageError`` is caught and printed after a 'Usage error:' line; any
    other exception propagates.
    """
    argv = sys.argv
    try:
        if len(argv) < 2:
            raise UsageError('Usage: %s <run_type> [<arguments>*]' % argv[0])
        run(argv[1], argv[2:])
    except UsageError:
        # Fetch the active exception without version-specific syntax.
        error = sys.exc_info()[1]
        print('Usage error:')
        print(error)

if __name__ == '__main__':
    run_cmdline()























# NOTE: the multiple-channel classes sketched below are disabled (commented out).

# ################################################################################
# ### multiple channels
# ################################################################################

# class MultipleChannel(Channel):
#     def switch(self, job, *message):
#         raise NotImplementedError('This Channel does not allow switching between jobs.')
#     def broadcast(self, *message):
#         raise NotImplementedError()

# class SpawnChannel(MultipleChannel):
#     # spawns one process for each task
#     pass

# class GreenletChannel(MultipleChannel):
#     # uses a single process for all tasks, using greenlets to switch between them
#     pass