view doc/v2_planning/arch_src/plugin_JB.py @ 1321:ebcb76b38817

tinyimages - added main script to whiten patches
author James Bergstra <bergstrj@iro.umontreal.ca>
date Sun, 10 Oct 2010 13:43:53 -0400
parents 9fac28d80fb7
children
line wrap: on
line source

"""plugin_JB - draft of potential library architecture using iterators

This strategy makes use of a simple imperative language whose statements are python function
calls to create learning algorithms that can be manipulated and executed in several desirable
ways.  

The training procedure for a PCA module is easy to express:

    # allocate the relevant modules
    dataset = Dataset(numpy.random.RandomState(123).randn(13,1))
    pca = PCA_Analysis()
    pca_batchsize=1000
    
    reg = Registers()

    # define the control-flow of the algorithm
    train_pca = SEQ([
        REPEAT(pca_batchsize, CALL(dataset.next, _set=reg('x'))),
        CALL(pca.analyze, reg('x'))])

    # run the program
    train_pca.run()

The CALL, SEQ, and REPEAT are control-flow elements. The control-flow elements I
defined so far are:

- CALL - a basic statement, just calls a python function
- SEQ - a sequence of elements to run in order
- REPEAT - do something N times (completes with the value the element returned on the last iteration)
- LOOP - do something an infinite number of times
- CHOOSE - like a switch statement (should rename to SWITCH; not implemented yet)
- WEAVE - interleave execution of multiple control-flow elements
- POPEN - launch a process and return its status when it's complete
- PRINT - a shortcut for CALL(print_obj)
- SPAWN - run a program fragment asynchronously in another process
- BUFFER_REPEAT - gather the return values of N executions of an element into a list


We don't have many requirements per se for the architecture, but I think this design respects
and realizes all of them.
The advantages of this approach are:

    - algorithms (including partially run ones) are COPYABLE, and SERIALIZABLE

    - algorithms can be executed without seizing control of the python process (the run()
      method does this, but if you look inside it you'll see it's a simple for loop)

      - it is easy to execute an algorithm step by step in a main loop that also checks for
        network or filesystem events related to e.g. job management.

    - the library can provide learning algorithms via control-flow templates, and the user can
      edit them (with search/replace calls) to include HOOKS, and DIAGNOSTIC plug-in
      functionality

      e.g. prog.find(CALL(cd1_update, layer=layer1)).replace_with(
          SEQ([CALL(cd1_update, layer=layer1), CALL(my_debugfn)]))

    - user can print the 'program code' of an algorithm built from library pieces

    - program can be optimized automatically.
      
      - e.g. BUFFER_REPEAT(N, CALL(dataset.next)) could be replaced if dataset.next implements
        the right attribute/protocol for 'bufferable' or something.

      - e.g. SEQ([a,b,c,d])  could be compiled to a single CALL to a Theano-compiled function
        if a, b, c, and d are calls to callable objects that export something like a
        'theano_SEQ' interface


"""

__license__ = 'TODO'
__copyright__ = 'TODO'

import cPickle, copy, os, subprocess, sys, time
import numpy

####################################################
# CONTROL-FLOW CONSTRUCTS

class INCOMPLETE:
    """
    Sentinel that ELEMENT.step() returns while an element has not finished yet.

    The class object itself is the value: drivers return it directly and test for it by
    identity (``r is INCOMPLETE``).  Any other return value from step() means "finished".
    """

class ELEMENT(object):
    """
    Base class for control flow elements (e.g. CALL, REPEAT, etc.)

    The design is that every element has a driver, that is another element, or the run()
    loop implemented in the ELEMENT class.

    the driver calls start when entering a new control element
       - this would be called once per e.g. outer loop iteration

    the driver calls step to advance the control element
       - which returns INCOMPLETE
       - which returns any other object to indicate completion
    """

    # subclasses should override these methods:
    def start(self):
        """(Re)initialize this element for a fresh execution.  The default does nothing."""
        pass

    def step(self):
        """Advance this element by one step.

        Return INCOMPLETE if more steps are needed, or any other object (the element's
        result) on completion.  The default completes immediately with None.
        """
        pass

    # subclasses should typically not override these:
    def run(self, n_steps=float('inf')):
        """Start this element and step it until it completes.

        :param n_steps: maximum number of calls to step() (default: unlimited).
        :returns: the value of the completing step(), or INCOMPLETE if the step budget
            ran out before the element completed.
        """
        self.start()
        r = INCOMPLETE
        n_taken = 0
        # exactly at most n_steps calls to step(); n_steps <= 0 performs none
        while r is INCOMPLETE and n_taken < n_steps:
            r = self.step()
            n_taken += 1
        return r

class BUFFER_REPEAT(ELEMENT):
    """
    Accumulate a number of return values into one list / array.

    The source of return values `src` is a control element that will be restarted repeatedly in
    order to fulfil the requirement of gathering N samples.  step() returns INCOMPLETE until N
    values have been gathered, then returns them as a list, in the order `src` produced them.
    A new list is allocated by every start().  With N <= 0 the first step() completes with an
    empty list without running `src`.

    TODO: support accumulating of tuples of arrays
    """
    def __init__(self, N, src, storage=None):
        """
        N - number of values to gather per execution
        src - control element whose completion values are gathered
        storage - preallocated storage; not supported yet (must be None)

        TODO: use preallocated `storage`
        """
        if storage is not None:
            raise NotImplementedError()
        self.N = N
        self.n = 0
        self.src = src
        self.storage = storage
        self.finished = False

    def start(self, arg=None):
        # `arg` is ignored; it is accepted only so that existing callers passing one keep
        # working.  Drivers in this module call start() without arguments.
        self.buf = [None] * self.N
        self.n = 0
        self.finished = False
        if self.N > 0:
            self.src.start()

    def step(self):
        assert not self.finished
        if self.n >= self.N:
            # nothing to gather (N <= 0)
            self.finished = True
            return self.buf
        r = self.src.step()
        if r is INCOMPLETE:
            return r
        self.buf[self.n] = r
        self.n += 1
        if self.n >= self.N:
            self.finished = True
            return self.buf
        # restart our stream, only when another value is still needed
        self.src.start()
        return INCOMPLETE

class CALL(ELEMENT):
    """
    Control flow terminal - call a python function or method.

    Positional and keyword arguments that are Register instances are replaced by the
    Register's current value each time the call is made.  The special keyword argument
    `_set` is not passed to `fn`; if present it must be a Register, and it receives the
    return value of the call.

    Completes in a single step, returning the return value of the call.
    """
    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def start(self):
        self.finished = False
        return self

    def step(self):
        assert not self.finished
        self.finished = True
        fn_rval = self.fn(*self.lookup_args(), **self.lookup_kwargs())
        if '_set' in self.kwargs:
            self.kwargs['_set'].set(fn_rval)
        # NOTE: if fn itself returns INCOMPLETE, drivers will treat this CALL as unfinished.
        return fn_rval

    def __getstate__(self):
        # Python 2's pickle cannot pickle instance methods directly: store a bound/unbound
        # method as its (function, instance, class) parts, rebuilt in __setstate__.
        rval = dict(self.__dict__)
        if type(self.fn) is type(self.step): #instancemethod
            fn = rval.pop('fn')
            rval['i fn'] = fn.im_func, fn.im_self, fn.im_class
        return rval

    def __setstate__(self, dct):
        if 'i fn' in dct:
            dct['fn'] = type(self.step)(*dct.pop('i fn'))
        self.__dict__.update(dct)

    def lookup_args(self):
        """Return the positional arguments as a list, each Register replaced by its value."""
        return [a.get() if isinstance(a, Register) else a for a in self.args]

    def lookup_kwargs(self):
        """Return the keyword arguments except `_set`, each Register replaced by its value."""
        return dict((k, v.get() if isinstance(v, Register) else v)
                    for k, v in self.kwargs.items()
                    if k != '_set')

def CHOOSE(which, options):
    """
    Select one of several control-flow paths (`options`) according to `which`, and run it.

    Placeholder only: always raises NotImplementedError.
    """
    raise NotImplementedError

def LOOP(element):
    """
    Repeat `element` forever.

    Built as a REPEAT whose iteration count is infinite, so the loop never completes on
    its own; it ends only when its driver stops stepping it.  (An infinite float avoids the
    Python-2-only, finite `sys.maxint` bound.)
    """
    return REPEAT(float('inf'), element)

class REPEAT(ELEMENT):
    """
    Execute `element` N times in a row, restarting it before each iteration.

    If `counter` (a Register) is given, start() sets it to 0 and it is set to the number of
    completed iterations after each iteration.

    step() returns INCOMPLETE until the last iteration completes, then returns the value the
    element completed with on that last iteration.  With N <= 0 the element is never started
    or stepped, and the first step() completes with None.

    Raises TypeError if `element` is not an ELEMENT.
    """
    def __init__(self, N, element, counter=None):
        self.N = N
        if not isinstance(element, ELEMENT):
            raise TypeError(element)
        self.element = element
        self.counter = counter

    #TODO: check for N being callable
    def start(self):
        self.n = 0   #loop iteration
        self.finished = False
        if self.N > 0:
            self.element.start()
        if self.counter is not None:
            self.counter.set(0)

    def step(self):
        assert not self.finished
        if self.n >= self.N:
            # zero-iteration loop: there is nothing to run
            self.finished = True
            return None
        r = self.element.step()
        if r is INCOMPLETE:
            return INCOMPLETE
        self.n += 1
        if self.counter is not None:
            self.counter.set(self.n)
        if self.n < self.N:
            self.element.start()
            return INCOMPLETE
        self.finished = True
        return r

class SEQ(ELEMENT):
    """
    Run a sequence of elements one after another, in order.

    Each element is started when the previous one completes.  step() returns INCOMPLETE
    until the last element completes, and then returns None.  An empty SEQ completes with
    None on its first step.
    """
    def __init__(self, elements):
        self.elements = list(elements)

    def start(self):
        # pos must be initialized even for an empty sequence, since step() reads it
        self.pos = 0
        self.finished = False
        if self.elements:
            self.elements[0].start()

    def step(self):
        assert not self.finished
        if self.pos >= len(self.elements):
            # only reachable for an empty sequence
            self.finished = True
            return None
        r = self.elements[self.pos].step()
        if r is INCOMPLETE:
            return r
        self.pos += 1
        if self.pos < len(self.elements):
            self.elements[self.pos].start()
            return INCOMPLETE
        # the last element just completed: the sequence is done
        self.finished = True
        return None

class WEAVE(ELEMENT):
    """
    Interleave execution of a number of elements.

    start() starts every element.  Each call to step() advances one element, visiting the
    elements round-robin and skipping those that have already completed.  The WEAVE
    completes (returning None) as soon as `n_required` of the elements have completed;
    elements still unfinished at that point are not stepped any further.

    n_required == -1 means "all of the elements".  Raises ValueError if n_required is larger
    than the number of elements, since that requirement could never be met.

    TODO: allow a schedule (at least relative frequency) of elements from each program
    """
    def __init__(self, n_required, elements):
        self.elements = elements
        if n_required == -1:
            self.n_required = len(elements)
        else:
            self.n_required = n_required
        # Requiring more completions than there are elements would make step() spin forever
        # looking for an unfinished element (or index into an empty list).
        if self.n_required > len(elements):
            raise ValueError('WEAVE: n_required (%s) exceeds the number of elements (%i)'
                             % (n_required, len(elements)))

    def start(self):
        for el in self.elements:
            el.start()
        self.elem_finished = [False] * len(self.elements)
        self.idx = 0
        self.finished = False

    def step(self):
        assert not self.finished # if this is triggered, we have a broken driver

        #start with this check in case there were no elements
        # it's possible for the number of finished elements to exceed the threshold
        if sum(self.elem_finished) >= self.n_required:
            self.finished = True
            return None

        # step the active element
        r = self.elements[self.idx].step()

        if r is not INCOMPLETE:
            self.elem_finished[self.idx] = True

            # check for completion
            if sum(self.elem_finished) >= self.n_required:
                self.finished = True
                return None

        # advance to the next un-finished element; one must exist because fewer than
        # n_required <= len(elements) elements have finished
        self.idx = (self.idx + 1) % len(self.elements)
        while self.elem_finished[self.idx]:
            self.idx = (self.idx + 1) % len(self.elements)

        return INCOMPLETE

class POPEN(ELEMENT):
    """
    Launch an external command and wait for it, without blocking, to exit.

    `args` is handed unchanged to subprocess.Popen when the element is started.  step()
    returns INCOMPLETE while the process is still running, and its exit status once it has
    terminated.
    """
    def __init__(self, args):
        self.args = args

    def start(self):
        self.p = subprocess.Popen(self.args)

    def step(self):
        status = self.p.poll()
        return INCOMPLETE if status is None else status

def PRINT(obj):
    """Build a CALL element that prints `obj` (through print_obj) when it is stepped."""
    element = CALL(print_obj, obj)
    return element

class SPAWN(ELEMENT):
    """
    Run `prog` in a separate python process, then merge that process's `data` back.

    start() pickles the (data, prog) pair and sends it on the child's stdin; the child
    (SPAWN._main) runs prog to completion and pickles `data` back through a pipe whose
    file-descriptor number is passed on the command line.  step() returns INCOMPLETE while
    the child runs.  Once the child exits successfully, the returned dict is merged into
    self.data in place (so other components holding the same dict see the new values), and
    step() returns None.  If the child fails, a message is written to stderr, self.data is
    left unchanged, and step() returns None.

    Uses select() on a pipe and fd inheritance by number, so this is POSIX-only.
    """
    SUCCESS = 0

    def __init__(self, data, prog):
        self.data = data
        self.prog = prog

    def start(self):
        # pickle the (data, prog) pair
        s = cPickle.dumps((self.data, self.prog))

        # call python with a stub function that
        # unpickles the data, prog pair and starts running the prog
        self.rpipe, wpipe = os.pipe()
        code = 'import sys, plugin_JB; sys.exit(plugin_JB.SPAWN._main(%i))'%wpipe
        self.p = subprocess.Popen(
                [sys.executable or 'python', '-c', code],
                stdin=subprocess.PIPE,
                close_fds=False)  # the child must inherit wpipe
        # The child has its own copy of the write end.  Closing ours lets the read end
        # report EOF once the child exits, instead of blocking forever.
        os.close(wpipe)
        self._chunks = []
        self.finished = False
        # send the data and prog to the other process; closing stdin flushes it and
        # releases the pipe
        self.p.stdin.write(s)
        self.p.stdin.close()

        #TODO: send over tgz of the modules this code needs

        #TODO: When the client process is on a different machine, negotiate with the client
        # process to determine which modules it needs, and send over the code for pure python
        # ones.  Make sure versions match for non-pure python ones.

    def _read_available(self, block):
        """Append bytes from the result pipe to self._chunks.

        With block=False, read only what can be read right now; with block=True, read until
        EOF (only safe once the child has exited).
        """
        import select
        while True:
            if not block:
                readable = select.select([self.rpipe], [], [], 0)[0]
                if not readable:
                    return
            chunk = os.read(self.rpipe, 65536)
            if not chunk:
                return  # EOF
            self._chunks.append(chunk)

    def step(self):
        assert not self.finished
        # Drain the result pipe as data arrives: a child writing a result larger than the
        # pipe buffer blocks until we read, so waiting for it to exit first would deadlock.
        self._read_available(block=False)
        r = self.p.poll()
        if r is None:
            return INCOMPLETE    # typical exit case
        self.finished = True
        # the child has exited: whatever it wrote is in the pipe, followed by EOF
        self._read_available(block=True)
        os.close(self.rpipe)
        payload = b''.join(self._chunks)
        self._chunks = []
        if r != self.SUCCESS:
            # the child did not finish writing a result; there is nothing to merge
            sys.stderr.write('SPAWN: child process exited with status %s\n' % (r,))
            #TODO: return something meaningful? like r?
            return None
        # recv the revised data dictionary, and modify the data dict in-place
        # for new values to be visible to other components
        # (payload comes from our own child process, so unpickling it is trusted)
        self.data.update(cPickle.loads(payload))
        #TODO: return something meaningful? like r?
        return None

    @staticmethod
    def _main(wpipe):
        """Child-process entry point used by the stub that SPAWN.start() runs.

        Reads the pickled (data, prog) pair from stdin, runs prog to completion, then
        pickles `data` to file descriptor `wpipe` and closes it (signalling EOF to the
        parent).  Returns SPAWN.SUCCESS, which the stub passes to sys.exit().
        """
        #TODO: unpack and install tgz of the modules this code needs
        data, prog = cPickle.load(sys.stdin)
        prog.run()
        wfile = os.fdopen(wpipe, 'wb')
        try:
            # a file object keeps writing until the whole pickle is sent
            cPickle.dump(data, wfile)
        finally:
            wfile.close()
        return SPAWN.SUCCESS

class Register(object):
    """
    A handle on one named slot of a `registers` mapping.

    get() and set() read and write ``registers[key]``, so every Register created for the
    same key on the same mapping shares one value.
    """
    def __init__(self, registers, key):
        self.registers, self.key = registers, key

    def set(self, val):
        """Store `val` in the mapping under this register's key."""
        mapping = self.registers
        mapping[self.key] = val

    def get(self):
        """Return the mapping's value for this register's key (the mapping's error if unset)."""
        mapping = self.registers
        return mapping[self.key]


class Registers(dict):
    """A dict of named values; calling it with a key returns a Register bound to that key."""
    def __call__(self, key):
        handle = Register(self, key)
        return handle

def print_obj(obj):
    """Print `obj` to standard output (the action performed by PRINT elements)."""
    print(obj)
def print_obj_attr(obj, attr):
    """Print the attribute of `obj` whose name is `attr`."""
    value = getattr(obj, attr)
    print(value)
def no_op(*args, **kwargs):
    """Accept any arguments, do nothing, and return None (a placeholder callable)."""
    return None

def importable_fn(d):
    """Set d['new key'] to the size `d` had before the call; modifies `d` in place."""
    size_before = len(d)
    d['new key'] = size_before


if __name__ == '__main__':
    # This module only defines the library; point the user at the runnable script.
    print('this is the library file, run "python plugin_JB_main.py"')