Mercurial > pylearn
changeset 1219:9fac28d80fb7
plugin_JB - removed FILT and BUFFER_REPEAT, added Registers
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 22 Sep 2010 13:31:31 -0400 |
parents | 5d1b5906151c |
children | 35fb6e9713d2 |
files | doc/v2_planning/arch_src/plugin_JB.py doc/v2_planning/arch_src/plugin_JB_main.py |
diffstat | 2 files changed, 139 insertions(+), 87 deletions(-) [+] |
line wrap: on
line diff
--- a/doc/v2_planning/arch_src/plugin_JB.py Wed Sep 22 13:00:39 2010 -0400 +++ b/doc/v2_planning/arch_src/plugin_JB.py Wed Sep 22 13:31:31 2010 -0400 @@ -10,28 +10,29 @@ dataset = Dataset(numpy.random.RandomState(123).randn(13,1)) pca = PCA_Analysis() pca_batchsize=1000 + + reg = Registers() # define the control-flow of the algorithm train_pca = SEQ([ - BUFFER_REPEAT(pca_batchsize, CALL(dataset.next)), - FILT(pca.analyze)]) + REPEAT(pca_batchsize, CALL(dataset.next, store_to=reg('x'))), + CALL(pca.analyze, reg('x'))]) # run the program train_pca.run() -The CALL, SEQ, FILT, and BUFFER_REPEAT are control-flow elements. The control-flow elements I +The CALL, SEQ, and REPEAT are control-flow elements. The control-flow elements I defined so far are: - CALL - a basic statement, just calls a python function -- FILT - like call, but passes the return value of the last CALL or FILT to the python function - SEQ - a sequence of elements to run in order - REPEAT - do something N times (and return None or maybe the last CALL?) -- BUFFER_REPEAT - do something N times and accumulate the return value from each iter - LOOP - do something an infinite number of times - CHOOSE - like a switch statement (should rename to SWITCH) - WEAVE - interleave execution of multiple control-flow elements - POPEN - launch a process and return its status when it's complete - PRINT - a shortcut for CALL(print_obj) +- SPAWN - run a program fragment asynchronously in another process We don't have many requirements per-se for the architecture, but I think this design respects @@ -95,14 +96,14 @@ """ # subclasses should override these methods: - def start(self, arg): + def start(self): pass def step(self): pass # subclasses should typically not override these: - def run(self, arg=None, n_steps=float('inf')): - self.start(arg) + def run(self, n_steps=float('inf')): + self.start() i = 0 r = self.step() while r is INCOMPLETE: @@ -162,20 +163,15 @@ self.fn = fn self.args = args self.kwargs=kwargs - self.use_start_arg = kwargs.pop('use_start_arg', False) - def start(self, arg): - self.start_arg = arg + def start(self): self.finished = False return self def step(self): assert not self.finished self.finished = True - if self.use_start_arg: - if self.args: - raise TypeError('cant get positional args both ways') - return self.fn(self.start_arg, **self.kwargs) - else: - return self.fn(*self.args, **self.kwargs) + fn_rval = self.fn(*self.lookup_args(), **self.lookup_kwargs()) + if '_set' in self.kwargs: + self.kwargs['_set'].set(fn_rval) def __getstate__(self): rval = dict(self.__dict__) if type(self.fn) is type(self.step): #instancemethod @@ -187,12 +183,24 @@ dct['fn'] = type(self.step)(*dct.pop('i fn')) self.__dict__.update(dct) -def FILT(fn, **kwargs): - """ - Return a CALL object that uses the return value from the previous CALL as the first and - only positional argument. - """ - return CALL(fn, use_start_arg=True, **kwargs) + def lookup_args(self): + rval = [] + for a in self.args: + if isinstance(a, Register): + rval.append(a.get()) + else: + rval.append(a) + return rval + def lookup_kwargs(self): + rval = {} + for k,v in self.kwargs.iteritems(): + if k == '_set': + continue + if isinstance(v, Register): + rval[k] = v.get() + else: + rval[k] = v + return rval def CHOOSE(which, options): """ @@ -200,46 +208,60 @@ """ raise NotImplementedError() -def LOOP(elements): +def LOOP(element): #TODO: implement a true infinite loop - try: - iter(elements) - return REPEAT(sys.maxint, elements) - except TypeError: - return REPEAT(sys.maxint, [elements]) + return REPEAT(sys.maxint, element) class REPEAT(ELEMENT): - def __init__(self, N, elements, pass_rvals=False): + def __init__(self, N, element, counter=None): self.N = N - self.elements = elements - self.pass_rvals = pass_rvals + if not isinstance(element, ELEMENT): + raise TypeError(element) + self.element = element + self.counter = counter #TODO: check for N being callable - def start(self, arg): + def start(self): self.n = 0 #loop iteration - self.idx = 0 #element idx self.finished = False - self.elements[0].start(arg) + self.element.start() + if self.counter: + self.counter.set(0) + def step(self): assert not self.finished - r = self.elements[self.idx].step() + r = self.element.step() if r is INCOMPLETE: return INCOMPLETE - self.idx += 1 - if self.idx < len(self.elements): - self.elements[self.idx].start(r) - return INCOMPLETE self.n += 1 + if self.counter: + self.counter.set(self.n) if self.n < self.N: - self.idx = 0 - self.elements[self.idx].start(r) + self.element.start() return INCOMPLETE else: self.finished = True return r -def SEQ(elements): - return REPEAT(1, elements) +class SEQ(ELEMENT): + def __init__(self, elements): + self.elements = list(elements) + def start(self): + if len(self.elements): + self.elements[0].start() + self.pos = 0 + self.finished = False + def step(self): + if self.pos == len(self.elements): + self.finished=True + return + r = self.elements[self.pos].step() + if r is INCOMPLETE: + return r + self.pos += 1 + if self.pos < len(self.elements): + self.elements[self.pos].start() + return INCOMPLETE class WEAVE(ELEMENT): """ @@ -253,9 +275,9 @@ self.n_required = len(elements) else: self.n_required = n_required - def start(self, arg): + def start(self): for el in self.elements: - el.start(arg) + el.start() self.elem_finished = [0] * len(self.elements) self.idx = 0 self.finished= False @@ -289,7 +311,7 @@ class POPEN(ELEMENT): def __init__(self, args): self.args = args - def start(self, arg): + def start(self): self.p = subprocess.Popen(self.args) def step(self): r = self.p.poll() @@ -305,7 +327,7 @@ def __init__(self, data, prog): self.data = data self.prog = prog - def start(self, arg): + def start(self): # pickle the (data, prog) pair s = cPickle.dumps((self.data, self.prog)) @@ -353,6 +375,17 @@ return SPAWN.SUCCESS #os.close(wpipe) +class Register(object): + def __init__(self, registers, key): + self.registers = registers + self.key = key + def set(self, val): + self.registers[self.key] = val + def get(self): + return self.registers[self.key] +class Registers(dict): + def __call__(self, key): + return Register(self, key) def print_obj(obj): print obj @@ -364,3 +397,6 @@ def importable_fn(d): d['new key'] = len(d) + +if __name__ == '__main__': + print 'this is the library file, run "python plugin_JB_main.py"'
--- a/doc/v2_planning/arch_src/plugin_JB_main.py Wed Sep 22 13:00:39 2010 -0400 +++ b/doc/v2_planning/arch_src/plugin_JB_main.py Wed Sep 22 13:31:31 2010 -0400 @@ -10,10 +10,10 @@ def __init__(self, data): self.pos = 0 self.data = data - def next(self): - rval = self.data[self.pos] - self.pos += 1 - if self.pos == len(self.data): + def next(self, n=1): + rval = self.data[self.pos:self.pos+n] + self.pos += n + if self.pos >= len(self.data): self.pos = 0 return rval def seek(self, pos): @@ -28,27 +28,29 @@ def next_fold(self): self.k += 1 self.data.seek(0) # restart the stream - def next(self): + def next(self, n=1): #TODO: skip the examples that are ommitted in this split - return self.data.next() + return self.data.next(n) def init_test(self): pass - def next_test(self): - return self.data.next() + def next_test(self, n=1): + return self.data.next(n) def test_size(self): return 5 def store_scores(self, scores): self.scores[self.k] = scores - def prog(self, clear, train, test): - return REPEAT(self.K, [ + def prog(self, clear, train, test, test_data_reg, test_counter_reg, test_scores_reg): + return REPEAT(self.K, SEQ([ CALL(self.next_fold), clear, train, CALL(self.init_test), - BUFFER_REPEAT(self.test_size(), - SEQ([ CALL(self.next_test), test])), - FILT(self.store_scores) ]) + REPEAT(self.test_size(), SEQ([ + CALL(self.next_test, _set=test_data_reg), + test]), + counter=test_counter_reg), + CALL(self.store_scores, test_scores_reg)])) class PCA_Analysis(object): def __init__(self): @@ -93,8 +95,8 @@ return l[0] print WEAVE(1, [ - BUFFER_REPEAT(3,CALL(f,1)), - BUFFER_REPEAT(5,CALL(f,1)), + REPEAT(3,CALL(f,1)), + REPEAT(5,CALL(f,1)), ]).run() def main_weave_popen(): @@ -104,9 +106,9 @@ p = WEAVE(2,[ SEQ([POPEN(['sleep', '5']), PRINT('done 1')]), SEQ([POPEN(['sleep', '10']), PRINT('done 2')]), - LOOP([ + LOOP(SEQ([ CALL(print_obj, 'polling...'), - CALL(time.sleep, 1)])]) + CALL(time.sleep, 1)]))]) # The LOOP would forever if the WEAVE were not configured to stop after 2 of its elements # complete. @@ -120,15 +122,15 @@ data1 = {0:"blah data1"} data2 = {1:"foo data2"} p = WEAVE(2,[ - SPAWN(data1, REPEAT(3, [ + SPAWN(data1, REPEAT(3, SEQ([ CALL(importable_fn, data1), - PRINT("hello from 1")])), - SPAWN(data2, REPEAT(1, [ + PRINT("hello from 1")]))), + SPAWN(data2, REPEAT(1, SEQ([ CALL(importable_fn, data2), - PRINT("hello from 2")])), - LOOP([ + PRINT("hello from 2")]))), + LOOP(SEQ([ CALL(print_obj, 'polling...'), - CALL(time.sleep, 0.5)])]) + CALL(time.sleep, 0.5)]))]) print 'BEFORE' print data1 print data2 @@ -148,6 +150,7 @@ layer1 = Layer(w=4) layer2 = Layer(w=3) kf = KFold(dataset, K=10) + reg = Registers() pca_batchsize=1000 cd_batchsize = 5 @@ -157,19 +160,19 @@ # create algorithm train_pca = SEQ([ - BUFFER_REPEAT(pca_batchsize, CALL(kf.next)), - FILT(pca.analyze)]) + CALL(kf.next, pca_batchsize, _set=reg('x')), + CALL(pca.analyze, reg('x'))]) - train_layer1 = REPEAT(n_cd_updates_layer1, [ - BUFFER_REPEAT(cd_batchsize, CALL(kf.next)), - FILT(pca.filt), - FILT(cd1_update, layer=layer1, lr=.01)]) + train_layer1 = REPEAT(n_cd_updates_layer1, SEQ([ + CALL(kf.next, cd_batchsize, _set=reg('x')), + CALL(pca.filt, reg('x'), _set=reg('x')), + CALL(cd1_update, reg('x'), layer=layer1, lr=.01)])) - train_layer2 = REPEAT(n_cd_updates_layer2, [ - BUFFER_REPEAT(cd_batchsize, CALL(kf.next)), - FILT(pca.filt), - FILT(layer1.filt), - FILT(cd1_update, layer=layer2, lr=.01)]) + train_layer2 = REPEAT(n_cd_updates_layer2, SEQ([ + CALL(kf.next, cd_batchsize, _set=reg('x')), + CALL(pca.filt, reg('x'), _set=reg('x')), + CALL(layer1.filt, reg('x'), _set=reg('x')), + CALL(cd1_update, reg('x'), layer=layer2, lr=.01)])) kfold_prog = kf.prog( clear = SEQ([ # FRAGMENT 1: this bit is the reset/clear stage @@ -181,14 +184,17 @@ train_pca, WEAVE(1, [ # Silly example of how to do debugging / loggin with WEAVE train_layer1, - LOOP(CALL(print_obj_attr, layer1, 'w'))]), + LOOP(PRINT(reg('x')))]), train_layer2, ]), test=SEQ([ - FILT(pca.filt), # may want to allow this SEQ to be - FILT(layer1.filt), # optimized into a shorter one that - FILT(layer2.filt), # compiles these calls together with - FILT(numpy.mean)])) # Theano + CALL(pca.filt, reg('testx'), _set=reg('x')), + CALL(layer1.filt, reg('x'), _set=reg('x')), + CALL(layer2.filt, reg('x'), _set=reg('x')), + CALL(numpy.mean, reg('x'), _set=reg('score'))]), + test_data_reg=reg('testx'), + test_counter_reg=reg('i'), + test_scores_reg=reg('score')) pkg1 = dict(prog=kfold_prog, kf=kf) pkg2 = copy.deepcopy(pkg1) # programs can be copied @@ -206,4 +212,14 @@ if __name__ == '__main__': + try: + sys.argv[1] + except: + print """You have to tell which main function to use, try: + - python plugin_JB_main.py 'main_kfold_dbn()' + - python plugin_JB_main.py 'main_weave()' + - python plugin_JB_main.py 'main_weave_popen()' + - python plugin_JB_main.py 'main_spawn()' + """ + sys.exit(1) sys.exit(eval(sys.argv[1]))