Mercurial > pylearn
diff doc/v2_planning/arch_src/plugin_JB.py @ 1219:9fac28d80fb7
plugin_JB - removed FILT and BUFFER_REPEAT, added Registers
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 22 Sep 2010 13:31:31 -0400 |
parents | 478bb1f8215c |
children |
line wrap: on
line diff
--- a/doc/v2_planning/arch_src/plugin_JB.py Wed Sep 22 13:00:39 2010 -0400 +++ b/doc/v2_planning/arch_src/plugin_JB.py Wed Sep 22 13:31:31 2010 -0400 @@ -10,28 +10,29 @@ dataset = Dataset(numpy.random.RandomState(123).randn(13,1)) pca = PCA_Analysis() pca_batchsize=1000 + + reg = Registers() # define the control-flow of the algorithm train_pca = SEQ([ - BUFFER_REPEAT(pca_batchsize, CALL(dataset.next)), - FILT(pca.analyze)]) + REPEAT(pca_batchsize, CALL(dataset.next, store_to=reg('x'))), + CALL(pca.analyze, reg('x'))]) # run the program train_pca.run() -The CALL, SEQ, FILT, and BUFFER_REPEAT are control-flow elements. The control-flow elements I +The CALL, SEQ, and REPEAT are control-flow elements. The control-flow elements I defined so far are: - CALL - a basic statement, just calls a python function -- FILT - like call, but passes the return value of the last CALL or FILT to the python function - SEQ - a sequence of elements to run in order - REPEAT - do something N times (and return None or maybe the last CALL?) -- BUFFER_REPEAT - do something N times and accumulate the return value from each iter - LOOP - do something an infinite number of times - CHOOSE - like a switch statement (should rename to SWITCH) - WEAVE - interleave execution of multiple control-flow elements - POPEN - launch a process and return its status when it's complete - PRINT - a shortcut for CALL(print_obj) +- SPAWN - run a program fragment asynchronously in another process We don't have many requirements per-se for the architecture, but I think this design respects @@ -95,14 +96,14 @@ """ # subclasses should override these methods: - def start(self, arg): + def start(self): pass def step(self): pass # subclasses should typically not override these: - def run(self, arg=None, n_steps=float('inf')): - self.start(arg) + def run(self, n_steps=float('inf')): + self.start() i = 0 r = self.step() while r is INCOMPLETE: @@ -162,20 +163,15 @@ self.fn = fn self.args = args self.kwargs=kwargs - self.use_start_arg = kwargs.pop('use_start_arg', False) - def start(self, arg): - self.start_arg = arg + def start(self): self.finished = False return self def step(self): assert not self.finished self.finished = True - if self.use_start_arg: - if self.args: - raise TypeError('cant get positional args both ways') - return self.fn(self.start_arg, **self.kwargs) - else: - return self.fn(*self.args, **self.kwargs) + fn_rval = self.fn(*self.lookup_args(), **self.lookup_kwargs()) + if '_set' in self.kwargs: + self.kwargs['_set'].set(fn_rval) def __getstate__(self): rval = dict(self.__dict__) if type(self.fn) is type(self.step): #instancemethod @@ -187,12 +183,24 @@ dct['fn'] = type(self.step)(*dct.pop('i fn')) self.__dict__.update(dct) -def FILT(fn, **kwargs): - """ - Return a CALL object that uses the return value from the previous CALL as the first and - only positional argument. - """ - return CALL(fn, use_start_arg=True, **kwargs) + def lookup_args(self): + rval = [] + for a in self.args: + if isinstance(a, Register): + rval.append(a.get()) + else: + rval.append(a) + return rval + def lookup_kwargs(self): + rval = {} + for k,v in self.kwargs.iteritems(): + if k == '_set': + continue + if isinstance(v, Register): + rval[k] = v.get() + else: + rval[k] = v + return rval def CHOOSE(which, options): """ @@ -200,46 +208,60 @@ """ raise NotImplementedError() -def LOOP(elements): +def LOOP(element): #TODO: implement a true infinite loop - try: - iter(elements) - return REPEAT(sys.maxint, elements) - except TypeError: - return REPEAT(sys.maxint, [elements]) + return REPEAT(sys.maxint, element) class REPEAT(ELEMENT): - def __init__(self, N, elements, pass_rvals=False): + def __init__(self, N, element, counter=None): self.N = N - self.elements = elements - self.pass_rvals = pass_rvals + if not isinstance(element, ELEMENT): + raise TypeError(element) + self.element = element + self.counter = counter #TODO: check for N being callable - def start(self, arg): + def start(self): self.n = 0 #loop iteration - self.idx = 0 #element idx self.finished = False - self.elements[0].start(arg) + self.element.start() + if self.counter: + self.counter.set(0) + def step(self): assert not self.finished - r = self.elements[self.idx].step() + r = self.element.step() if r is INCOMPLETE: return INCOMPLETE - self.idx += 1 - if self.idx < len(self.elements): - self.elements[self.idx].start(r) - return INCOMPLETE self.n += 1 + if self.counter: + self.counter.set(self.n) if self.n < self.N: - self.idx = 0 - self.elements[self.idx].start(r) + self.element.start() return INCOMPLETE else: self.finished = True return r -def SEQ(elements): - return REPEAT(1, elements) +class SEQ(ELEMENT): + def __init__(self, elements): + self.elements = list(elements) + def start(self): + if len(self.elements): + self.elements[0].start() + self.pos = 0 + self.finished = False + def step(self): + if self.pos == len(self.elements): + self.finished=True + return + r = self.elements[self.pos].step() + if r is INCOMPLETE: + return r + self.pos += 1 + if self.pos < len(self.elements): + self.elements[self.pos].start() + return INCOMPLETE class WEAVE(ELEMENT): """ @@ -253,9 +275,9 @@ self.n_required = len(elements) else: self.n_required = n_required - def start(self, arg): + def start(self): for el in self.elements: - el.start(arg) + el.start() self.elem_finished = [0] * len(self.elements) self.idx = 0 self.finished= False @@ -289,7 +311,7 @@ class POPEN(ELEMENT): def __init__(self, args): self.args = args - def start(self, arg): + def start(self): self.p = subprocess.Popen(self.args) def step(self): r = self.p.poll() @@ -305,7 +327,7 @@ def __init__(self, data, prog): self.data = data self.prog = prog - def start(self, arg): + def start(self): # pickle the (data, prog) pair s = cPickle.dumps((self.data, self.prog)) @@ -353,6 +375,17 @@ return SPAWN.SUCCESS #os.close(wpipe) +class Register(object): + def __init__(self, registers, key): + self.registers = registers + self.key = key + def set(self, val): + self.registers[self.key] = val + def get(self): + return self.registers[self.key] +class Registers(dict): + def __call__(self, key): + return Register(self, key) def print_obj(obj): print obj @@ -364,3 +397,6 @@ def importable_fn(d): d['new key'] = len(d) + +if __name__ == '__main__': + print 'this is the library file, run "python plugin_JB_main.py"'