Mercurial > pylearn
changeset 1210:cbe1fb32686c
v2planning plugin_JB - added n_required keyword to WEAVE
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 21 Sep 2010 23:38:53 -0400 |
parents | 5ff1d375fc33 |
children | e7ac87720fee |
files | doc/v2_planning/plugin_JB.py |
diffstat | 1 files changed, 30 insertions(+), 12 deletions(-) [+] |
line wrap: on
line diff
--- a/doc/v2_planning/plugin_JB.py Tue Sep 21 16:27:47 2010 -0400 +++ b/doc/v2_planning/plugin_JB.py Tue Sep 21 23:38:53 2010 -0400 @@ -246,26 +246,44 @@ TODO: allow a schedule (at least relative frequency) of elements from each program """ - def __init__(self, elements): + def __init__(self, n_required, elements): self.elements = elements + if n_required == -1: + self.n_required = len(elements) + else: + self.n_required = n_required def start(self, arg): for el in self.elements: el.start(arg) + self.elem_finished = [0] * len(self.elements) self.idx = 0 - self.any_is_finished = False self.finished= False def step(self): assert not self.finished # if this is triggered, we have a broken driver - self.idx = self.idx % len(self.elements) + + #start with this check in case there were no elements + # it's possible for the number of finished elements to exceed the threshold + if sum(self.elem_finished) >= self.n_required: + self.finished = True + return None + + # step the active element r = self.elements[self.idx].step() + if r is not INCOMPLETE: - self.any_is_finished = True - self.idx += 1 - if self.idx == len(self.elements) and self.any_is_finished: - self.finished = True - return None # dummy completion value - else: - return INCOMPLETE + self.elem_finished[self.idx] = True + + # check for completion + if sum(self.elem_finished) >= self.n_required: + self.finished = True + return None + + # advance to the next un-finished element + self.idx = (self.idx+1) % len(self.elements) + while self.elem_finished[self.idx]: + self.idx = (self.idx+1) % len(self.elements) + + return INCOMPLETE #################################################### @@ -359,7 +377,7 @@ l[0] += a return l[0] - print WEAVE([ + print WEAVE(1, [ BUFFER_REPEAT(3,CALL(f,1)), BUFFER_REPEAT(5,CALL(f,1)), ]).run() @@ -402,7 +420,7 @@ ]), train = SEQ([ train_pca, - WEAVE([ # Silly example of how to do debugging / loggin with WEAVE + WEAVE(1, [ # Silly example of how to do debugging / loggin with WEAVE train_layer1, LOOP(CALL(print_obj_attr, layer1, 'w'))]), train_layer2,