view pylearn/dbdict/tools.py @ 540:85d3300c9a9c

m
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 13 Nov 2008 17:54:56 -0500
parents 798607a058bd
children ee5324c21e60
line wrap: on
line source

import sys

from .experiment  import COMPLETE, INCOMPLETE

# Names of the attributes a job `state` object carries to identify its job
# function: the module to import, the symbol to fetch from it, and an
# optional whitespace-separated list of modules to import beforehand.
MODULE, SYMBOL, PREIMPORT = 'dbdict_module', 'dbdict_symbol', 'dbdict_preimport'

def dummy_channel(*args, **kwargs):
    """No-op channel: accept and ignore any arguments, always return None."""

#
# This proxy object lets experiments use a plain dict as if it were a state object.
#
def DictProxyState(dct):
    """Wrap the mapping `dct` in an object usable as an experiment state.

    The returned proxy supports both ``proxy.key`` and ``proxy['key']`` for
    reading and writing entries of `dct`; writes go straight into `dct`.
    A read of a key missing from `dct` falls back to the same-named
    attribute of an optional defaults object installed with
    ``proxy.use_defaults(obj)`` (passing None removes it).

    Missing names raise KeyError from item access and AttributeError from
    attribute access, so ``getattr(proxy, name, default)`` works.
    """
    # One-element list so the methods below can rebind the defaults object
    # without `nonlocal` (not available in Python 2).
    defaults_obj = [None]
    # Sentinel meaning "no default available"; cannot collide with a real value.
    missing = object()

    def lookup_default(name):
        # Return attribute `name` of the defaults object, or `missing` when
        # no defaults object is installed or it has no such attribute.
        # Checking for None explicitly stops lookups from resolving to
        # attributes of None itself (e.g. `__bool__`).
        obj = defaults_obj[0]
        if obj is None:
            return missing
        try:
            return getattr(obj, name)
        except (AttributeError, TypeError):
            # TypeError: `name` is not a string (possible via item access).
            return missing

    class Proxy(object):
        def subdict(s, prefix=''):
            """Return {key[len(prefix):]: value} for keys of `dct` starting with `prefix`."""
            n = len(prefix)
            return dict((k[n:], v) for k, v in dct.items() if k.startswith(prefix))

        def use_defaults(s, obj):
            """Install `obj` as the fallback for names missing from `dct`."""
            defaults_obj[0] = obj

        def __getitem__(s, a):
            try:
                return dct[a]
            except KeyError:
                pass
            value = lookup_default(a)
            if value is missing:
                raise KeyError(a)
            return value

        def __setitem__(s, a, v):
            dct[a] = v

        def __getattr__(s, a):
            # Only invoked when normal attribute lookup fails, so the methods
            # defined on this class are never shadowed by entries of `dct`.
            try:
                return dct[a]
            except KeyError:
                pass
            value = lookup_default(a)
            if value is missing:
                raise AttributeError(a)
            return value

        def __setattr__(s, a, v):
            # Every attribute assignment is stored in `dct`, not on the proxy.
            dct[a] = v

    return Proxy()

def load_state_fn(state):
    """Import and return the job function named by `state`.

    `state` must provide the module path as its MODULE attribute and the
    function name as its SYMBOL attribute (attribute names given by those
    module constants). The optional PREIMPORT attribute is a
    whitespace-separated list of modules imported first.

    Any error from importing the module or fetching the symbol is re-raised
    after the module name, symbol and sys.path are written to stderr.
    """
    def _import(name):
        # __import__('a.b.c') returns the top-level package 'a'; the fully
        # imported leaf module is then found in sys.modules. This replaces
        # the fromlist=[None] trick, which raises TypeError when the leaf
        # module is a package.
        __import__(name, level=0)
        return sys.modules[name]

    dbdict_module_name = getattr(state, MODULE)
    dbdict_symbol = getattr(state, SYMBOL)

    for preimport in getattr(state, PREIMPORT, "").split():
        _import(preimport)

    try:
        dbdict_module = _import(dbdict_module_name)
        dbdict_fn = getattr(dbdict_module, dbdict_symbol)
    except Exception:
        # sys.stderr.write (rather than `print >>`) works on Python 2 and 3;
        # the output matches the original print statements.
        sys.stderr.write("FAILED to load job symbol: %s %s\n"
                         % (dbdict_module_name, dbdict_symbol))
        sys.stderr.write("PATH %s\n" % (sys.path,))
        raise

    return dbdict_fn


def run_state(state, channel=dummy_channel):
    """Load the job function described by `state` and run it.

    The job function is called as ``fn(state, channel)`` and its return
    value is passed back unchanged. If that value is neither COMPLETE nor
    INCOMPLETE, a warning is written to stderr.
    """
    fn = load_state_fn(state)
    rval = fn(state, channel)
    if rval not in (COMPLETE, INCOMPLETE):
        # sys.stderr.write instead of `print >>`, which fails at runtime on Python 3.
        sys.stderr.write("WARNING: INVALID job function return value\n")
    return rval