# view pylearn/dbdict/ @ 622:d2d582bcf7dc
#
# api0 and sql seem to work with session closing
# author James Bergstra <>
# date Sun, 18 Jan 2009 21:39:28 -0500
# parents 7041749cf804
# children ae954c27fd11
# line wrap: on
# line source

import sys, os, copy, time

import numpy.random

import sqlalchemy
import sqlalchemy.pool
from sqlalchemy import create_engine, desc
from sqlalchemy.orm import eagerload
import psycopg2, psycopg2.extensions 

from api0 import db_from_engine, postgres_db, DbHandle

EXPERIMENT = 'dbdict.experiment'
# Using the dictionary key/value pairs to store these is too slow.
STATUS = 'dbdict.status'
PRIORITY = 'dbdict.sql.priority'

HOST = 'dbdict.sql.hostname'
HOST_WORKDIR = 'dbdict.sql.host_workdir'
PUSH_ERROR = 'dbdict.sql.push_error'

# Values stored under STATUS.  START and RUNNING are used by
# book_dct_postgres_serial and insert_dict, so they must be defined here.
# NOTE(review): assumes the numbering START=0 < RUNNING=1 < DONE=2 matches the
# rows already stored in the database -- confirm before deploying.
START = 0
RUNNING = 1
DONE = 2

# When true, book_dct_postgres_serial takes its 'SLEEPING BEFORE BOOKING'
# concurrency-testing path.
_TEST_CONCURRENCY = False

def postgres_serial(user, password, host, database, poolclass=sqlalchemy.pool.NullPool, **kwargs):
    """Return a DbHandle instance that communicates with a postgres database at transaction
    isolation_level 'SERIALIZABLE'.

    :param user: a username in the database
    :param password: the password for the username
    :param host: the network address of the host on which the postgres server is running
    :param database: a database served by the postgres server
    :param poolclass: the sqlalchemy pool class used by the engine (default: NullPool)
    :param kwargs: forwarded unchanged to `db_from_engine`

    One engine is created per distinct (user, password, host, database, poolclass)
    and cached on this function.  Repeated calls with the same connection parameters
    therefore share an engine, while a call for a different database gets its own
    engine instead of silently reusing the first one.
    """
    this = postgres_serial

    if not hasattr(this, '_engines'):
        this._engines = {}

    key = (user, password, host, database, poolclass)
    if key not in this._engines:
        def connect():
            c = psycopg2.connect(user=user, password=password, database=database, host=host)
            # Honour the documented contract: every transaction on this
            # connection runs at isolation level SERIALIZABLE.
            c.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE)
            return c
        # The URL only selects the dialect; `creator` supplies the real connections.
        this._engines[key] = create_engine('postgres://',
                creator=connect,
                poolclass=poolclass)

    # Kept for backward compatibility with code that reads postgres_serial.engine.
    this.engine = this._engines[key]

    db = db_from_engine(this.engine, **kwargs)
    db._is_serialized_session_db = True
    return db

def book_dct_postgres_serial(db, retry_max_sleep=10.0, verbose=0):
    """Find a trial in the lisa_db with status START.

    A trial will be returned with status=RUNNING.

    Returns None if no such trial exists in DB.

    This function uses a serial access to the lisadb to guarantee that no other
    process will retrieve the same dct.  It is designed to facilitate writing
    a "consumer" for a Producer-Consumer pattern based on the database.

    print >> sys.stderr, """#TODO: use the priority field, not the status."""
    print >> sys.stderr, """#TODO: ignore entries with key PUSH_ERROR."""

    s = db.session() #open a new session

    # NB. we need the query and attribute update to be in the same transaction
    assert s.autocommit == False 

    dcts_seen = set([])
    keep_trying = True

    dct = None
    while (dct is None) and keep_trying:
        #build a query
        q = s.query(db._Dict)
        #q = q.options(eagerload('_attrs')) #hard-coded in api0

        # use dedicated column to retrieve jobs, not the dictionary keyval pair
        # This should be much faster.
        q = q.filter(db._Dict.status==START)

        #try to reserve a dct
            #first() may raise psycopg2.ProgrammingError
            dct = q.first()

            if dct is not None:
                assert (dct not in dcts_seen)
                if verbose: print 'book_unstarted_dct retrieved, ', dct
                dct._set_in_session(STATUS, RUNNING, s)
                if 1:
                    if _TEST_CONCURRENCY:
                        print >> sys.stderr, 'SLEEPING BEFORE BOOKING'

                    #commit() may raise psycopg2.ProgrammingError
                    print >> sys.stderr, 'DEBUG MODE: NOT RESERVING JOB!', dct
                #if we get this far, the job is ours!
                # no jobs are left
                keep_trying = False
        except (psycopg2.OperationalError,
                sqlalchemy.exc.DBAPIError), e:
            #either the first() or the commit() raised
            s.rollback() # docs say to do this (or close) after commit raises exception
            if verbose: print 'caught exception', e
            if dct:
                # first() succeeded, commit() failed
                dct = None
            wait = numpy.random.rand(1)*retry_max_sleep
            if verbose: print 'another process stole our dct. Waiting %f secs' % wait
    return dct

def book_dct_non_postgres(db):
    """Placeholder booking routine for non-postgres databases.

    Writes the outstanding TODO reminders to stderr and then always raises
    NotImplementedError; `db` is not used.
    """
    reminders = ('#TODO: use the priority field, not the status.',
                 '#TODO: ignore entries with key self.push_error.')
    for note in reminders:
        print >> sys.stderr, note
    raise NotImplementedError()

# Connect

def parse_dbstring(dbstring):
    """Split a dbdict connection string into its components.

    The expected format is ``postgres://username[:password]@hostname/dbname/tablename``.
    If no ``:password`` part is present, the password is looked up with `get_password`.
    An explicitly empty password (``username:@...``) is returned as ``''``.
    Everything after the ``/`` that follows dbname (possibly nothing) is the table name.

    :returns: (username, password, hostname, dbname, tablename)
    :raises ValueError: if `dbstring` does not start with ``postgres://``, has no ``@``
        before the host, or has no ``/`` after the host.
    """
    postgres = 'postgres://'
    if not dbstring.startswith(postgres):
        raise ValueError('For now, dbdict dbstrings must start with postgres://', dbstring)
    rest = dbstring[len(postgres):]

    # Split on the LAST '@' so that a password containing '@' still parses.
    # NOTE: assumes host, database and table names never contain '@'.
    username_and_password, at, rest = rest.rpartition('@')
    if not at:
        raise ValueError('dbdict dbstring is missing "@hostname"', dbstring)

    # No ':' at all means "no password given" (as opposed to an empty one).
    username, colon, password = username_and_password.partition(':')
    if not colon:
        password = None

    hostname, slash, rest = rest.partition('/')
    if not slash:
        raise ValueError('dbdict dbstring is missing "/dbname"', dbstring)

    dbname, _, tablename = rest.partition('/')

    if password is None:
        password = get_password(hostname, dbname)

    return username, password, hostname, dbname, tablename

def get_password(hostname, dbname):
    """Return the current user's password for a given database

    The password is the first line of the file ``~/.dbdict_<dbname>``, without its
    line terminator.  `hostname` is currently not used.

    :TODO: Replace this mechanism with a section in the pylearn configuration file
    :raises ValueError: if the password file cannot be opened or read
    """
    password_path = os.path.join(os.path.expanduser('~'), '.dbdict_%s' % dbname)
    try:
        f = open(password_path)
        try:
            line = f.readline()
        finally:
            f.close()
    except (IOError, OSError):
        raise ValueError( 'Failed to read password for db "%s" from %s' % (dbname, password_path))
    # Strip only the line terminator: slicing off the last character would also
    # eat a real password character when the file has no trailing newline.
    return line.rstrip('\r\n')

def db(dbstring):
    """Return a database handle for the table described by `dbstring`.

    :param dbstring: ``postgres://username[:password]@hostname/dbname/tablename``,
        as understood by `parse_dbstring`; tablename is used as the table prefix
    :raises: whatever `parse_dbstring` or `postgres_db` raises.  On a connection
        failure a short diagnostic naming the user, host and database -- but never
        the password -- is written to stderr before the exception is re-raised.
    """
    username, password, hostname, dbname, tablename = parse_dbstring(dbstring)
    try:
        return postgres_db(username, password, hostname, dbname, table_prefix=tablename)
    except Exception:
        # Do not echo the password itself: it would end up in logs and terminals.
        sys.stderr.write('Error connecting to db "%s" on host "%s" as user "%s"\n'
                % (dbname, hostname, username))
        raise

# Queue

def insert_dict(jobdict, db, force_dup=False, session=None):
    """Insert a new `job` dictionary into database `db`.

    A shallow copy of `jobdict` is inserted with STATUS set to START and PRIORITY
    set to 1.0; `jobdict` itself is left untouched.

    :param force_dup: forces insertion even if an identical dictionary is already in the db
    :param session: session used for the duplicate check and passed on to `db.insert`.
        If None, a temporary session is opened just for the duplicate check and closed
        again afterwards, and `db.insert` receives session=None.
    :returns: the value returned by `db.insert`, or None when an identical dictionary
        already exists and `force_dup` is false
    """
    job = copy.copy(jobdict)
    if force_dup:
        do_insert = True
    elif session is None:
        s = db.session()
        try:
            do_insert = None is db.query(s).filter_eq_dct(job).first()
        finally:
            # Release the temporary session's connection even if the query fails.
            s.close()
    else:
        do_insert = None is db.query(session).filter_eq_dct(job).first()

    if not do_insert:
        return None
    job[STATUS] = START
    job[PRIORITY] = 1.0
    return db.insert(job, session=session)

def insert_job(experiment_fn, state, db, force_dup=False, session=None):
    """Queue `experiment_fn` to be run with parameters `state`.

    A shallow copy of `state` is tagged under the EXPERIMENT key with the dotted
    ``module.name`` path of `experiment_fn`, then passed to `insert_dict` together
    with `db`, `force_dup` and `session`; its return value is returned.
    """
    tagged = copy.copy(state)
    tagged[EXPERIMENT] = '.'.join((experiment_fn.__module__, experiment_fn.__name__))
    return insert_dict(tagged, db, force_dup=force_dup, session=session)

def add_experiments_to_db(jobs, db, verbose=0, add_dups=False, type_check=None, session=None):
    """Add experiments paramatrized by jobs[i] to database db.

    Default behaviour is to ignore jobs which are already in the database.

    If type_check is a class (instead of None) then it will be used as a type declaration for
    all the elements in each job dictionary.  For each key,value pair in the dictionary, there
    must exist an attribute,value pair in the class meeting the following criteria:
    the attribute and the key are equal, and the types of the values are equal.

    :param jobs: The parameters of experiments to run.
    :type jobs: an iterable object over dictionaries
    :param verbose: print which jobs are added and which are skipped
    :param add_dups: False will ignore a job if it matches (on all items()) with a db entry.
    :type add_dups: Bool

    :returns: list of (Bool,job[i]) in which the flags mean the corresponding job actually was

    rval = []
    for job in jobs:
        job = copy.copy(job)
        if session is None:
            s = db.session()
            do_insert = force_dup or (None is db.query(s).filter_eq_dct(job).first())
            do_insert = force_dup or (None is db.query(session).filter_eq_dct(job).first())

        if do_insert:
            if type_check:
                for k,v in job.items():
                    if type(v) != getattr(type_check, k):
                        raise TypeError('Experiment contains value with wrong type',
                                ((k,v), getattr(type_check, k)))

            job[STATUS] = START
            job[PRIORITY] = 1.0
            if verbose:
                print 'ADDING  ', job
            rval.append((True, job))
            if verbose:
                print 'SKIPPING', job
            rval.append((False, job))