# view doc/v2_planning/plugin_greenlet.py @ 1195:d3ee0d2d03e6
#
# plugin_greenlet draft0
# author James Bergstra <bergstrj@iro.umontreal.ca>
# date Sun, 19 Sep 2010 13:05:48 -0400
# parents
# children e9bb3340a870
# line wrap: on
# line source

"""plugin_greenlet - draft of library architecture using greenlets"""

# Licensing metadata has not been filled in for this draft; both are None.
__license__ = None
__copyright__ = None

import copy, sys

import numpy
from greenlet import greenlet

def vm_unpack(incoming):
    """Split a switch payload into ``(first, second, rest, kwargs)``.

    `incoming` is normalized to a ``(positional, keywords)`` pair:

    * None                      -> no positionals, no keywords
    * a dict                    -> keywords only
    * a 2-tuple (tuple, dict)   -> taken as (positional, keywords) as-is
    * any other tuple           -> positionals only
    * anything else             -> a single positional

    The first two positionals are then returned individually, followed by
    the remaining positionals (a tuple) and the keyword dict.

    NOTE(review): the first two values are presumably the target greenlet
    and the return destination used by vm_loop -- nothing visible in this
    file calls vm_unpack, so confirm against the intended caller.

    Raises IndexError when fewer than two positional values are available
    (which is always the case for None, dict and single-value payloads).
    """
    # can't reliably distinguish between a kwargs-only switch and a switch with one dict
    # argument
    if incoming is None:
        positional, keywords = (), {}
    # `elif` (not `if`): otherwise the final `else` below would overwrite the
    # None case with ((None,), {}).
    elif isinstance(incoming, dict):
        positional, keywords = (), incoming
    elif isinstance(incoming, tuple):
        if (len(incoming) == 2
                and isinstance(incoming[0], tuple)
                and isinstance(incoming[1], dict)):
            positional, keywords = incoming
        else:
            positional, keywords = incoming, {}
    else:
        positional, keywords = (incoming,), {}
    #print 'unpack', incoming, rval
    if len(positional) < 2:
        # Same exception type the bare indexing would raise, with a message.
        raise IndexError(
            'vm_unpack needs at least two positional values, got %d'
            % len(positional))
    return positional[0], positional[1], positional[2:], keywords

def unpack_from_vm(incoming):
    """Check that `incoming` is a tuple of length 4 and return it unchanged.

    Raises AssertionError when the check fails.
    """
    assert isinstance(incoming, tuple) and 4 == len(incoming)
    return incoming

def vm_run(prog, *args, **kwargs):

    def vm_loop(gr, dest, a, kw):
        while True:
            if gr == 'return':
                return a, kw
            print 'vm_loop gr=',gr,'args=',a, 'kwargs=', kw
            gr, dest, a, kw = gr.switch(vm, gr, dest, a, kw)
            #print 'gmain incoming', incoming
    vm = greenlet(vm_loop)

    return vm.switch(prog, 'return', args, kwargs)


def seq(glets):
    """Return a task that makes a single pass over `glets` (``repeat`` with N=1)."""
    single_pass = repeat(1, glets)
    return single_pass

def repeat(N, glets):
    def repeat_task(vm, gself, dest, args, kwargs):
        while True:
            for i in xrange(N):
                for glet in glets:
                    print 'repeat_task_i dest=%(dest)s args=%(args)s, kw=%(kwargs)s'%locals()
                    # jump to task `glet`
                    # with instructions to report results back to this loop `g`
                    _vm, _gself, _dest, args, kwargs = vm.switch(glet, gself, args, kwargs)
                    assert _gself is gself
                    assert _dest is None # instructions can't tell us where to jump
            vm, gself, dest, args, kwargs = vm.switch(dest, None, args, kwargs)
    return greenlet(repeat_task)

def choose(which, options):
    """Placeholder with no behavior yet; always raises NotImplementedError."""
    raise NotImplementedError

def weave(threads):
    """Placeholder with no behavior yet; always raises NotImplementedError."""
    raise NotImplementedError

def service(fn):
    """
    Create a greenlet whose first argument is the return-jump location.

    fn must accept as the first positional argument this greenlet itself, which can be used as
    the return-jump location for internal greenlet switches (ideally using gswitch).
    """
    def service_loop(vm, gself, dest, args, kwargs):
        while True:
            print 'service calling', fn.__name__, args, kwargs
            t = fn(vm, gself, *args, **kwargs)
            #TODO consider a protocol for returning args, kwargs
            if t is None:
                _vm,_gself,dest, args, kwargs = vm.switch(dest, None, (), {})
            else:
                _vm,_gself,dest, args, kwargs = vm.switch(dest, None, (t,), {})

            assert gself is _gself
    return greenlet(service_loop)

####################################################

class Dataset(object):
    """Cursor that walks over an indexable `data` object and wraps around."""

    def __init__(self, data):
        self.data = data
        # index of the item the following next() call will return
        self.pos = 0

    def next(self, vm, gself):
        """Return the item under the cursor and advance the cursor.

        The cursor goes back to 0 once it reaches ``len(self.data)``.
        `vm` and `gself` are accepted but not used by this method.
        """
        current = self.data[self.pos]
        following = self.pos + 1
        self.pos = 0 if following == len(self.data) else following
        return current

class PCA_Analysis(object):
    """Stub of a PCA component.

    Only the mean is actually estimated; `eigvecs` and `eigvals` are
    placeholder scalars (0 before `analyze`, 1 after).
    """

    def __init__(self):
        self.mean = 0
        self.eigvecs = 0
        self.eigvals = 0

    def analyze(self, me, X):
        """Store the mean of `X` along axis 0 and set the placeholders to 1.

        `me` is accepted but unused.  Returns None.
        """
        self.mean = X.mean(axis=0)
        self.eigvecs = 1
        self.eigvals = 1

    def filt(self, me, X):
        """Return ``(X - self.mean) * self.eigvecs``.

        `me` is accepted but unused.
        """
        # Use the argument: `self.X` is never assigned anywhere in the class,
        # so reading it raised AttributeError on every call.
        return (X - self.mean) * self.eigvecs #TODO: divide by root eigvals?

    def pseudo_inverse(self, Y):
        """Return `Y` unchanged (placeholder)."""
        return Y

class Layer(object):
    """Holds a weight `w` and multiplies inputs by it."""

    def __init__(self, w):
        self.w = w

    def filt(self, x):
        """Return ``self.w * x``."""
        weight = self.w
        return weight * x

def batches(src, N):
    # src is a service
    def rval(me):
        print 'batches src=', src, 'me=', me
        return numpy.asarray([gswitch(src, me)[0][0] for i in range(N)])
    return rval

def print_obj(vm, gself, obj):
    """Print `obj`; `vm` and `gself` are accepted but unused."""
    print obj
def no_op(*args, **kwargs):
    """Accept any arguments, do nothing, and return None."""
    return None

def build_pca_trainer(data_src, pca_module, N):
    """Sketch: wrap a batch source feeding `pca_module.analyze` in a greenlet.

    NOTE(review): this looks unfinished and cannot run as written:
      * `data_src` and `N` are never used (the call hard-codes N=5);
      * `inf_data` and `layer1_trainer` are not defined in the visible part
        of this file;
      * `batches` as defined in this file accepts only (src, N), not `dest`.
    """
    return greenlet(
                batches(
                    N=5,
                    src=inf_data,
                    dest=flow(pca_module.analyze,
                        dest=layer1_trainer)))

def main():
    """Demo: fetch and print dataset rows through the vm, three per run.

    Builds a 10x2 random dataset (seed 123), chains a service that fetches
    the next row with a service that prints its argument, repeats that pair
    three times per run, and runs the same program object twice.
    """
    rng = numpy.random.RandomState(123)
    dataset = Dataset(rng.randn(10,2))

    fetch_row = service(dataset.next)
    show_row = service(print_obj)
    prog = repeat(3, [fetch_row, show_row])

    # the same program object is run twice in a row
    for _run in range(2):
        vm_run(prog)


def main_arch():
    """Architecture sketch: K-fold training pipeline built from seq/repeat.

    NOTE(review): reads as pseudocode rather than runnable code.  `KFold`,
    `np_batch`, `pca`, `cd1_update`, `train_layer2`, `train_classifier`,
    `save_classifier`, `test_classifier` and `gswitch` are not defined in
    the visible part of this file, and `pca_module` / `layer2` are created
    but never used.
    """

    # create components
    dataset = Dataset(numpy.random.RandomState(123).randn(10,2))
    pca_module = PCA_Analysis()
    layer1 = Layer(w=4)
    layer2 = Layer(w=3)
    kf = KFold(dataset, K=10)

    # create algorithm

    # one pass: gather a batch via kf.next, then pca.analyze
    train_pca = seq([ np_batch(kf.next, 1000), pca.analyze])
    # 100 passes of: next item, pca.filt, cd1_update on layer1
    train_layer1 = repeat(100, [kf.next, pca.filt, cd1_update(layer1, lr=.01)])

    # 10 passes of: KFold.step, the full training sequence, KFold.set_score
    algo = repeat(10, [
        KFold.step,
        seq([train_pca,
            train_layer1,
            train_layer2,
            train_classifier,
            save_classifier,
            test_classifier]),
        KFold.set_score])

    gswitch(algo)


def main1():
    """Sketch: batch rows from a dataset service and pass them to PCA analysis.

    NOTE(review): not runnable as written.  `layer1`, `driver`,
    `cd1_trainer` and `gswitch` are not defined in the visible part of this
    file, and `flow` / `sink` are defined only after the ``__main__`` guard.
    """
    dataset = Dataset(numpy.random.RandomState(123).randn(10,2))
    pca_module = PCA_Analysis()

    # pca
    next_data = service(dataset.next)
    # service producing arrays of 5 items pulled from next_data
    b5 = service(batches(src=next_data, N=5))
    print_pca_analyze = flow(pca_module.analyze, dest=sink(print_obj))

    # layer1_training
    layer1_training = driver(
            fn=cd1_trainer(layer1),
            srcs=[],
            )

    gswitch(b5, print_pca_analyze)
    
# NOTE: when run as a script, sys.exit() raises SystemExit right here, so the
# definitions placed after this guard (flow, sink, consumer) are never
# executed; they only exist when the module is imported.
if __name__ == '__main__':
    sys.exit(main())



def flow(fn, dest):
    def rval(*args, **kwargs):
        while True:
            print 'flow calling', fn.__name__, args, kwargs
            t = fn(g, *args, **kwargs)
            args, kwargs = gswitch(dest, t)
    g = greenlet(rval)
    return g

def sink(fn):
    """Return a greenlet that calls `fn` once and finishes with its result.

    `fn` receives the returned greenlet as its first positional argument,
    followed by whatever arguments the greenlet was started with.
    """
    def run_once(*args, **kwargs):
        outcome = fn(task, *args, **kwargs)
        return outcome
    task = greenlet(run_once)
    return task

def consumer(fn, src):
    """Return a greenlet that endlessly feeds `fn` with values from `src`.

    Each round calls ``gswitch(src, *args, **kwargs)`` with the arguments
    the greenlet was started with and passes the result to `fn`.

    NOTE(review): `gswitch` is not defined in the visible part of this file.
    """
    def consume_forever(*args, **kwargs):
        while True:
            received = gswitch(src, *args, **kwargs)
            fn(received)
    task = greenlet(consume_forever)
    return task