Mercurial > pylearn
view doc/v2_planning/arch_src/plugin_JB.py @ 1321:ebcb76b38817
tinyimages - added main script to whiten patches
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Sun, 10 Oct 2010 13:43:53 -0400 |
parents | 9fac28d80fb7 |
children |
line wrap: on
line source
"""plugin_JB - draft of potential library architecture using iterators

This strategy makes use of a simple imperative language whose statements are python function
calls to create learning algorithms that can be manipulated and executed in several desirable
ways.

The training procedure for a PCA module is easy to express:

    # allocate the relevant modules
    dataset = Dataset(numpy.random.RandomState(123).randn(13,1))
    pca = PCA_Analysis()
    pca_batchsize=1000

    reg = Registers()

    # define the control-flow of the algorithm
    train_pca = SEQ([
        REPEAT(pca_batchsize, CALL(dataset.next, _set=reg('x'))),
        CALL(pca.analyze, reg('x'))])

    # run the program
    train_pca.run()

The CALL, SEQ, and REPEAT are control-flow elements. The control-flow elements I defined so far
are:

- CALL - a basic statement, just calls a python function (pass `_set=reg(key)` to store the
  return value in a register)
- SEQ - a sequence of elements to run in order
- REPEAT - do something N times (and return the value of the last iteration)
- LOOP - do something an infinite number of times
- CHOOSE - like a switch statement (should rename to SWITCH)
- WEAVE - interleave execution of multiple control-flow elements
- POPEN - launch a process and return its status when it's complete
- PRINT - a shortcut for CALL(print_obj)
- SPAWN - run a program fragment asynchronously in another process
- BUFFER_REPEAT - gather the return values of N runs of an element into a list

We don't have many requirements per-se for the architecture, but I think this design respects
and realizes all of them. The advantages of this approach are:

- algorithms (including partially run ones) are COPYABLE, and SERIALIZABLE
- algorithms can be executed without seizing control of the python process (the run() method
  does this, but if you look inside it you'll see it's a simple loop)
- it is easy to execute an algorithm step by step in a main loop that also checks for network
  or filesystem events related to e.g. job management.
- the library can provide learning algorithms via control-flow templates, and the user can edit
  them (with search/replace calls) to include HOOKS, and DIAGNOSTIC plug-in functionality, e.g.

      prog.find(CALL(cd1_update, layer=layer1)).replace_with(
          SEQ([CALL(cd1_update, layer=layer1), CALL(my_debugfn)]))

- user can print the 'program code' of an algorithm built from library pieces
- program can be optimized automatically.
  - e.g. BUFFER(N, CALL(dataset.next)) could be replaced if dataset.next implements the right
    attribute/protocol for 'bufferable' or something.
  - e.g. SEQ([a,b,c,d]) could be compiled to a single CALL to a Theano-compiled function if
    a, b, c, and d are calls to callable objects that export something like a 'theano_SEQ'
    interface
"""

__license__ = 'TODO'
__copyright__ = 'TODO'

import copy, os, subprocess, sys, tempfile, time, types
try:
    import cPickle
except ImportError:  # Python 3: no cPickle module; `pickle` is the equivalent
    import pickle as cPickle

import numpy

####################################################
# CONTROL-FLOW CONSTRUCTS


class INCOMPLETE:
    """Sentinel returned by ELEMENT.step meaning 'not finished yet'.

    Compared by identity (`r is INCOMPLETE`); the class itself is the value.
    """


class ELEMENT(object):
    """
    Base class for control flow elements (e.g. CALL, REPEAT, etc.)

    The design is that every element has a driver, that is another element, or the iterator
    implementation in the ELEMENT class (run()).

    the driver calls start when entering a new control element
       - this would be called once per e.g. outer loop iteration

    the driver calls step to advance the control element
       - which returns INCOMPLETE
       - which returns any other object to indicate completion
    """

    # subclasses should override these methods:
    def start(self):
        pass

    def step(self):
        pass

    # subclasses should typically not override these:
    def run(self, n_steps=float('inf')):
        """Start this element, then step it until it completes.

        At most `n_steps` calls to step() are made.  Returns the completion value of the
        element, or INCOMPLETE if the step budget ran out before it finished.
        """
        self.start()
        r = INCOMPLETE
        n_done = 0
        while r is INCOMPLETE and n_done < n_steps:
            r = self.step()
            n_done += 1
        return r


class BUFFER_REPEAT(ELEMENT):
    """
    Accumulate a number of return values into one list / array.

    The source of return values `src` is a control element that will be restarted repeatedly
    in order to fulfil the requirement of gathering N samples.  Completes with the list of
    the N values (an empty list when N <= 0).

    TODO: support accumulating of tuples of arrays
    """

    def __init__(self, N, src, storage=None):
        """
        TODO: use preallocated `storage`
        """
        if storage is not None:
            raise NotImplementedError()
        self.N = N
        self.n = 0
        self.src = src
        self.storage = storage

    def start(self, arg=None):
        # `arg` is ignored; it stays optional so old callers passing it still work, while
        # drivers (run, REPEAT, SEQ, WEAVE) call start() without arguments.
        self.buf = [None] * self.N
        self.n = 0
        self.finished = False
        self.src.start()

    def step(self):
        assert not self.finished
        if self.n < self.N:
            r = self.src.step()
            if r is INCOMPLETE:
                return INCOMPLETE
            self.buf[self.n] = r
            self.n += 1
            if self.n < self.N:
                self.src.start()  # restart our stream for the next sample
                return INCOMPLETE
        self.finished = True
        return self.buf


class CALL(ELEMENT):
    """
    Control flow terminal - call a python function or method.

    Positional and keyword arguments that are Register instances are replaced by the
    register's current value at call time.  The special keyword `_set` (a Register) is not
    forwarded to `fn`; the return value of the call is stored into it instead.

    Returns the return value of the call.
    """

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def start(self):
        self.finished = False
        return self

    def step(self):
        assert not self.finished
        self.finished = True
        fn_rval = self.fn(*self.lookup_args(), **self.lookup_kwargs())
        if '_set' in self.kwargs:
            self.kwargs['_set'].set(fn_rval)
        return fn_rval

    def __getstate__(self):
        # Bound methods are stored as their parts so that copy/pickle handle the function
        # and the instance separately; __setstate__ rebuilds the method.
        rval = dict(self.__dict__)
        if isinstance(self.fn, types.MethodType):
            fn = rval.pop('fn')
            rval['i fn'] = fn.__func__, fn.__self__, getattr(fn, 'im_class', None)
        return rval

    def __setstate__(self, dct):
        dct = dict(dct)
        if 'i fn' in dct:
            parts = dct.pop('i fn')
            func, obj = parts[0], parts[1]
            cls = parts[2] if len(parts) > 2 else None
            if cls is not None and sys.version_info[0] < 3:
                dct['fn'] = types.MethodType(func, obj, cls)
            else:
                # Python 3's MethodType takes only (function, instance)
                dct['fn'] = types.MethodType(func, obj)
        self.__dict__.update(dct)

    def lookup_args(self):
        """Return the positional arguments with each Register replaced by its value."""
        return [a.get() if isinstance(a, Register) else a for a in self.args]

    def lookup_kwargs(self):
        """Return the keyword arguments (minus `_set`) with Registers replaced by values."""
        return dict((k, v.get() if isinstance(v, Register) else v)
                    for k, v in self.kwargs.items()
                    if k != '_set')


def CHOOSE(which, options):
    """
    Execute one out of a number of optional control flow paths
    """
    raise NotImplementedError()


def LOOP(element):
    """Repeat `element` forever (a REPEAT whose count is infinite)."""
    return REPEAT(float('inf'), element)


class REPEAT(ELEMENT):
    """
    Run `element` to completion N times, restarting it between iterations.

    If `counter` (a Register) is given, it holds the number of completed iterations.
    Completes with the value returned by the last iteration, or None when N <= 0.
    """

    def __init__(self, N, element, counter=None):
        self.N = N
        if not isinstance(element, ELEMENT):
            raise TypeError(element)
        self.element = element
        self.counter = counter
        #TODO: check for N being callable

    def start(self):
        self.n = 0  # loop iteration
        self.finished = False
        if self.n < self.N:
            self.element.start()
        if self.counter is not None:
            self.counter.set(0)

    def step(self):
        assert not self.finished
        if self.n >= self.N:
            # zero (or negative) repetitions requested: nothing to run
            self.finished = True
            return None
        r = self.element.step()
        if r is INCOMPLETE:
            return INCOMPLETE
        self.n += 1
        if self.counter is not None:
            self.counter.set(self.n)
        if self.n < self.N:
            self.element.start()
            return INCOMPLETE
        self.finished = True
        return r


class SEQ(ELEMENT):
    """Run a list of elements one after another; completes with None."""

    def __init__(self, elements):
        self.elements = list(elements)

    def start(self):
        if self.elements:
            self.elements[0].start()
        self.pos = 0
        self.finished = False

    def step(self):
        assert not self.finished
        if self.pos < len(self.elements):
            r = self.elements[self.pos].step()
            if r is INCOMPLETE:
                return INCOMPLETE
            self.pos += 1
            if self.pos < len(self.elements):
                self.elements[self.pos].start()
                return INCOMPLETE
        # finish on the same step as the last element (or immediately if empty)
        self.finished = True
        return None


class WEAVE(ELEMENT):
    """
    Interleave execution of a number of elements.

    Each step() advances one unfinished element, round-robin.  The WEAVE completes (with
    None) once `n_required` elements have completed; `n_required=-1` means all of them.

    TODO: allow a schedule (at least relative frequency) of elements from each program
    """

    def __init__(self, n_required, elements):
        self.elements = list(elements)
        if n_required == -1:
            self.n_required = len(self.elements)
        else:
            self.n_required = n_required
        if self.n_required > len(self.elements):
            # step() could never be satisfied and would spin forever looking for an
            # unfinished element.
            raise ValueError('n_required=%r exceeds the number of elements (%i)'
                             % (n_required, len(self.elements)))

    def start(self):
        for el in self.elements:
            el.start()
        self.elem_finished = [False] * len(self.elements)
        self.idx = 0
        self.finished = False

    def step(self):
        assert not self.finished  # if this is triggered, we have a broken driver

        # start with this check in case there were no elements (or n_required <= 0)
        if sum(self.elem_finished) >= self.n_required:
            self.finished = True
            return None

        # step the active element
        r = self.elements[self.idx].step()
        if r is not INCOMPLETE:
            self.elem_finished[self.idx] = True
            # check for completion
            if sum(self.elem_finished) >= self.n_required:
                self.finished = True
                return None

        # advance to the next un-finished element; one exists because fewer than
        # n_required <= len(elements) elements have finished
        self.idx = (self.idx + 1) % len(self.elements)
        while self.elem_finished[self.idx]:
            self.idx = (self.idx + 1) % len(self.elements)
        return INCOMPLETE


class POPEN(ELEMENT):
    """Launch a subprocess on start(); complete with its exit status once it terminates."""

    def __init__(self, args):
        self.args = args

    def start(self):
        self.p = subprocess.Popen(self.args)
        self.finished = False

    def step(self):
        assert not self.finished
        r = self.p.poll()
        if r is None:
            return INCOMPLETE
        self.finished = True
        return r


def PRINT(obj):
    """Shortcut for CALL(print_obj, obj)."""
    return CALL(print_obj, obj)


class SPAWN(ELEMENT):
    """
    Run a program fragment in another python process.

    `data` (a dict) and `prog` (an ELEMENT) are pickled together and sent to a child
    interpreter, which runs `prog` to completion and pickles `data` back.  When the child
    exits successfully, the returned entries are merged into `self.data` in place.

    Raises RuntimeError from step() if the child exits with a non-SUCCESS status.
    """
    SUCCESS = 0

    def __init__(self, data, prog):
        self.data = data
        self.prog = prog

    def start(self):
        # pickle the (data, prog) pair
        s = cPickle.dumps((self.data, self.prog))

        # The child sends `data` back through a temporary file rather than a pipe: we only
        # read after the child exits, so a result larger than the pipe's capacity would
        # block the child forever, and a crashed child would leave us waiting on a pipe
        # whose write end we still held.
        fd, self.out_path = tempfile.mkstemp(prefix='plugin_JB_spawn_', suffix='.pkl')
        os.close(fd)

        # call python (the same interpreter, so pickles are compatible) with a stub that
        # unpickles the data, prog pair and starts running the prog
        code = ('import sys, plugin_JB; sys.exit(plugin_JB.SPAWN._main(%r))'
                % (self.out_path,))
        self.p = subprocess.Popen([sys.executable, '-c', code], stdin=subprocess.PIPE)

        # send the data and prog to the other process; closing stdin flushes the stream
        self.p.stdin.write(s)
        self.p.stdin.close()
        self.finished = False
        #TODO: send over tgz of the modules this code needs
        #TODO: When the client process is on a different machine, negotiate with the client
        # process to determine which modules it needs, and send over the code for pure python
        # ones. Make sure versions match for non-pure python ones.

    def step(self):
        assert not self.finished
        r = self.p.poll()
        if r is None:
            return INCOMPLETE

        # typical exit case
        self.finished = True
        try:
            if r != self.SUCCESS:
                raise RuntimeError('SPAWN: child process exited with status %r' % (r,))
            # recv the revised version of the data dictionary
            with open(self.out_path, 'rb') as rfile:
                data = cPickle.load(rfile)
        finally:
            os.remove(self.out_path)

        # modify the data dict in-place
        # for new values to be visible to other components
        self.data.update(data)

        #TODO: return something meaningful? like r?
        return None

    @staticmethod
    def _main(dest):
        """Child-process entry point.

        Reads the pickled (data, prog) pair from stdin, runs prog, and pickles `data` to
        `dest`: a file path, or (for backward compatibility) a writable file descriptor.
        """
        #TODO: unpack and install tgz of the modules this code needs
        stdin = getattr(sys.stdin, 'buffer', sys.stdin)  # pickles are bytes on Python 3
        data, prog = cPickle.load(stdin)
        prog.run()
        if isinstance(dest, int):
            out = os.fdopen(dest, 'wb')
        else:
            out = open(dest, 'wb')
        with out:
            cPickle.dump(data, out)
        return SPAWN.SUCCESS


class Register(object):
    """A handle on one key of a Registers dict; CALL resolves it to the current value."""

    def __init__(self, registers, key):
        self.registers = registers
        self.key = key

    def set(self, val):
        self.registers[self.key] = val

    def get(self):
        return self.registers[self.key]


class Registers(dict):
    """dict of register values; calling it with a key returns a Register for that key."""

    def __call__(self, key):
        return Register(self, key)


def print_obj(obj):
    print(obj)


def print_obj_attr(obj, attr):
    print(getattr(obj, attr))


def no_op(*args, **kwargs):
    pass


def importable_fn(d):
    d['new key'] = len(d)


if __name__ == '__main__':
    print('this is the library file, run "python plugin_JB_main.py"')