Mercurial > pylearn
comparison doc/v2_planning/plugin_JB.py @ 1198:1387771296a8
v2planning adding plugin_JB
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 20 Sep 2010 02:34:23 -0400 |
parents | |
children | 98954d8cb92d |
comparison
equal
deleted
inserted
replaced
1197:a60b3472c4ba | 1198:1387771296a8 |
---|---|
1 """plugin_JB - draft of library architecture using iterators""" | |
2 | |
3 | |
4 """ | |
5 | |
- PICKLABLE - algorithms are serializable at all points during execution

- ITERATOR - walks through algorithms with fine granularity

- COMPONENTS - library provides components on which programs operate

- ALGORITHMS - library provides algorithms in clean (no hooks) form

- HOOKS - user can insert print / debug logic with search/replace type calls
    e.g. prog.find(CALL(cd1_update)).replace_with(SEQ([CALL(cd1_update), CALL(debugfn)]))

- PRINTING - user can print the 'program code' of an algorithm built from library pieces

- MODULAR EXPERIMENTS - an experiment object holds one (or more?) programs and all of the
    objects referred to by those programs.  It is the preferred type of object to be
    serialized.  The main components of the algorithms should be top-level attributes of
    the package.  This object can be serialized and loaded in another process to implement
    job migration.

- OPTIMIZATION - program can be optimized automatically
    e.g. BUFFER_REPEAT(N, CALL(dataset.next)) can be replaced if dataset.next implements
    the right attribute/protocol for 'bufferable' or something.

    e.g. SEQ([a,b,c,d]) can be compiled with Theano if the sub-sequence is compatible

- NO GREENLETS - we don't need greenlets/stackless to get efficiency: the implementations
    of the control-flow ops can manage a stack (or stack tree) inside the VM themselves,
    much as greenlets do (I think).
33 | |
34 """ | |
35 | |
# Module metadata: not set in this draft.
__license__ = __copyright__ = None
38 | |
39 import copy, sys, cPickle | |
40 | |
41 import numpy | |
42 | |
43 | |
44 ################################################### | |
45 # Virtual Machine for executing programs | |
46 | |
class VirtualMachine(object):
    """
    Drive a program (a tree of control elements) one step at a time.

    Iterating over a VirtualMachine starts the program with argument None
    and then yields the result of each `prog.step()` call.  It yields
    INCOMPLETE while the program is still running, then the program's
    completion value, and after that iteration stops.
    """
    def __init__(self, prog):
        self.prog = prog
        self.started = False
        self.finished = False

    def __iter__(self):
        assert not self.started
        self._start()
        return self

    def _start(self):
        # Enter the top-level element exactly once.
        self.prog.start(None)
        self.started = True

    def next(self):
        """Execute one step of the program and return its result."""
        if self.finished:
            raise StopIteration()
        r = self.prog.step()
        if r is not INCOMPLETE:
            self.finished = True
        return r

    __next__ = next  # Python 3 iterator protocol

    def run(self, n_steps=float('inf')):
        """
        Execute at most `n_steps` more steps of the program.

        The program is started on the first call.  Later calls resume where
        the previous call stopped, so a long run can be split into chunks,
        e.g. to checkpoint the VM in between.

        Returns the result of the last step executed.  That is the
        program's completion value if it finished, INCOMPLETE if `n_steps`
        ran out first, or None if no step was executed.
        """
        if not self.started:
            self._start()
        r = None
        i = 0
        # Check the budget *before* stepping, so exactly n_steps steps run
        # (the original for-loop stepped once more before checking).
        while i < n_steps and not self.finished:
            r = self.next()
            i += 1
        return r
73 | |
74 | |
75 #################################################### | |
76 # CONTROL-FLOW CONSTRUCTS | |
77 | |
class INCOMPLETE:
    """
    Sentinel returned by ELEMENT.step while an element needs more steps.

    The class itself is the value (it is never instantiated); drivers test
    `r is INCOMPLETE`.  Any other value returned by `step` signals
    completion.
    """
80 | |
class ELEMENT(object):
    """
    Base class of every execution block (control element).

    A driver (the VirtualMachine, or an enclosing control element) calls
    `start(arg)` when entering the element, e.g. once per outer loop
    iteration.  It then calls `step()` repeatedly to advance the element.
    `step` returns INCOMPLETE while more steps are needed; any other object
    signals completion and is the element's result.

    The defaults do nothing: `start` ignores its argument, and `step`
    returns None, i.e. completes immediately with None.
    """

    def start(self, arg):
        return None

    def step(self):
        return None
97 | |
98 | |
class BUFFER_REPEAT(ELEMENT):
    """
    Accumulate a number of return values into one list.

    The source of return values `src` is a control element that will be
    restarted repeatedly in order to fulfil the requirement of gathering N
    samples.  Each `step` advances `src` by one step.  Whenever `src`
    completes, its value is stored and `src` is restarted with None.  After
    the N-th value the BUFFER_REPEAT completes and returns the list of the
    N values.  Every `start` allocates a fresh list, so a buffer already
    returned is never overwritten by a later run.

    If N <= 0 the first `step` completes immediately with an empty list,
    without stepping `src`.

    TODO: support accumulating of tuples of arrays
    """
    def __init__(self, N, src, storage=None):
        """
        N -- number of values to gather per run
        src -- control element producing one value each time it completes
        storage -- must be None for now

        Raises NotImplementedError if `storage` is given.

        TODO: use preallocated `storage`
        """
        # Use `is not None`: with an array argument, `!=` compares
        # element-wise and can fail with an ambiguous truth value instead of
        # raising the intended NotImplementedError.
        if storage is not None:
            raise NotImplementedError()
        self.N = N
        self.n = 0
        self.src = src
        self.storage = storage
        # src is started here and restarted after every value it produces,
        # so `start` below never has to start it.
        self.src.start(None)

    def start(self, arg):
        # `arg` is ignored: src is always (re)started with None.
        self.buf = [None] * self.N
        self.n = 0
        self.finished = False

    def step(self):
        assert not self.finished
        if self.n >= self.N:
            # Only reachable when N <= 0: there is nothing to gather.
            self.finished = True
            return self.buf
        r = self.src.step()
        if r is INCOMPLETE:
            return r
        self.src.start(None)  # restart our stream
        self.buf[self.n] = r
        self.n += 1
        if self.n == self.N:
            self.finished = True
            return self.buf
        return INCOMPLETE
137 | |
class CALL(ELEMENT):
    """
    Control flow terminal - call a python function or method.

    The element completes on its first step and returns the return value of
    the call.  With use_start_arg=True, the argument passed to `start`
    becomes the only positional argument of the call (see FILT).
    """
    def __init__(self, fn, *args, **kwargs):
        # Remove our own option before storing the caller's kwargs, so it is
        # never forwarded to `fn`.
        self.use_start_arg = kwargs.pop('use_start_arg', False)
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def start(self, arg):
        self.start_arg = arg
        self.finished = False
        return self

    def step(self):
        assert not self.finished
        self.finished = True
        if self.use_start_arg:
            if self.args:
                raise TypeError('cant get positional args both ways')
            return self.fn(self.start_arg, **self.kwargs)
        else:
            return self.fn(*self.args, **self.kwargs)

    def __getstate__(self):
        # Work on a copy: popping 'fn' from self.__dict__ itself would strip
        # the callable from the live object as a side effect of pickling.
        rval = self.__dict__.copy()
        if type(self.fn) is type(self.step):  # bound (or py2 unbound) method
            fn = rval.pop('fn')
            # __func__/__self__ exist on Python 2.6+ and 3.  im_class exists
            # only on Python 2, where it is needed to rebuild unbound methods.
            rval['i fn'] = (fn.__func__, fn.__self__,
                            getattr(fn, 'im_class', None))
        return rval

    def __setstate__(self, dct):
        dct = dict(dct)  # don't mutate the caller's dict
        if 'i fn' in dct:
            parts = dct.pop('i fn')
            method_type = type(self.step)
            try:
                # Python 2: instancemethod(function, instance, class)
                fn = method_type(*parts)
            except TypeError:
                # Python 3: MethodType(function, instance) only
                fn = method_type(parts[0], parts[1])
            dct['fn'] = fn
        self.__dict__.update(dct)
172 | |
def FILT(*args, **kwargs):
    """
    Build a CALL whose function receives the element's start argument.

    FILT(fn, **kw) returns a CALL that, when stepped, returns fn(arg, **kw),
    where `arg` is the value passed to `start`.  Typically that is the
    return value of the previous element in a SEQ/REPEAT.
    """
    return CALL(*args, use_start_arg=True, **kwargs)
175 | |
def CHOOSE(which, options):
    """
    Execute one out of a number of optional control flow paths.

    Not implemented yet: always raises NotImplementedError.
    """
    raise NotImplementedError
181 | |
def LOOP(elements):
    """
    Repeat `elements` indefinitely.

    `elements` is either a single control element or a sequence of them,
    run in order as in SEQ.  The iteration count is infinite, so for a
    non-empty `elements` the REPEAT keeps restarting its elements and never
    completes on its own.  Run it next to a finite element inside WEAVE, or
    stop the driver after a number of steps.
    """
    try:
        iter(elements)
    except TypeError:
        # A single (non-iterable) element: loop over just that one.
        elements = [elements]
    else:
        if not isinstance(elements, (list, tuple)):
            # REPEAT indexes and len()s its elements, so materialize
            # generators and other one-shot iterables.
            elements = list(elements)
    # float('inf') instead of sys.maxint gives a true infinite loop
    # (`n < inf` always holds) and also works on Python 3.
    return REPEAT(float('inf'), elements)
189 | |
class REPEAT(ELEMENT):
    """
    Run `elements` in sequence, `N` times over.

    Each element is started with the return value of the previous one.  The
    first element of the first iteration receives the argument passed to
    `start`.  The first element of each later iteration receives the return
    value of the last element of the previous iteration.  The return value
    of the last element on the final iteration is the completion value.

    If N <= 0 or `elements` is empty, no element is run: the first `step`
    completes immediately and returns the argument that was passed to
    `start`.
    """
    def __init__(self, N, elements, pass_rvals=False):
        self.N = N
        self.elements = elements
        self.pass_rvals = pass_rvals  # NOTE(review): not used anywhere yet
        #TODO: check for N being callable

    def _runs_nothing(self):
        # True when there is no element iteration to perform at all.
        return self.N <= 0 or not self.elements

    def start(self, arg):
        self.n = 0  # loop iteration
        self.idx = 0  # element idx
        self.finished = False
        self.start_arg = arg
        if not self._runs_nothing():
            self.elements[0].start(arg)

    def step(self):
        assert not self.finished
        if self._runs_nothing():
            # Previously N == 0 still ran one full iteration, and an empty
            # element list raised IndexError.
            self.finished = True
            return self.start_arg
        r = self.elements[self.idx].step()
        if r is INCOMPLETE:
            return INCOMPLETE
        self.idx += 1
        if self.idx < len(self.elements):
            self.elements[self.idx].start(r)
            return INCOMPLETE
        self.n += 1
        if self.n < self.N:
            self.idx = 0
            self.elements[self.idx].start(r)
            return INCOMPLETE
        self.finished = True
        return r
218 | |
def SEQ(elements):
    """
    Run `elements` once, in order, feeding each element's return value to
    the next element's `start`.  This is a REPEAT with a single iteration.
    """
    return REPEAT(N=1, elements=elements)
221 | |
class WEAVE(ELEMENT):
    """
    Interleave execution of a number of elements, round-robin.

    All elements are started with the same argument.  Each `step` advances
    exactly one element.  Once any element has completed, the current round
    is played out: the elements after it each still get their step.  The
    WEAVE then completes with the dummy value None.

    TODO: allow a schedule (at least relative frequency) of elements from each program
    """
    def __init__(self, elements):
        self.elements = elements

    def start(self, arg):
        for element in self.elements:
            element.start(arg)
        self.idx = 0
        self.any_is_finished = False
        self.finished = False

    def step(self):
        assert not self.finished  # if this is triggered, we have a broken driver
        count = len(self.elements)
        self.idx %= count
        if self.elements[self.idx].step() is not INCOMPLETE:
            self.any_is_finished = True
        self.idx += 1
        end_of_round = self.idx == count
        if not (end_of_round and self.any_is_finished):
            return INCOMPLETE
        self.finished = True
        return None  # dummy completion value
248 | |
249 | |
250 #################################################### | |
251 # [Dummy] Components involved in learning algorithms | |
252 | |
class Dataset(object):
    """Endless cyclic stream over the items of `data`."""

    def __init__(self, data):
        self.pos = 0
        self.data = data

    def next(self):
        """Return the item at the cursor, then advance it, wrapping to 0 after the last item."""
        item = self.data[self.pos]
        following = self.pos + 1
        self.pos = 0 if following == len(self.data) else following
        return item

    def seek(self, pos):
        """Move the cursor to `pos` (not validated)."""
        self.pos = pos
265 | |
class KFold(object):
    """
    Dummy K-fold cross-validation wrapper around a stream.

    `data` must provide `next()` and `seek(pos)`, as Dataset does.  This
    class only does fold bookkeeping; no examples are actually held out
    yet.
    """

    def __init__(self, data, K):
        self.data = data
        self.k = -1  # current fold index; -1 until next_fold() is called
        self.scores = [None] * K  # one slot per fold, filled by store_scores
        self.K = K

    def next_fold(self):
        """Advance to the next fold and rewind the underlying stream."""
        self.k = self.k + 1
        self.data.seek(0)

    def next(self):
        """Return the next training example."""
        #TODO: skip the examples that are omitted in this split
        return self.data.next()

    def init_test(self):
        """Prepare for test examples (currently does nothing)."""
        return None

    def next_test(self):
        """Return the next test example (currently just the next stream item)."""
        return self.data.next()

    def test_size(self):
        """Number of test examples per fold (hard-coded dummy value)."""
        return 5

    def store_scores(self, scores):
        """Record `scores` for the current fold."""
        # NOTE(review): if called before next_fold(), k == -1 and the scores
        # silently land in the LAST slot.
        self.scores[self.k] = scores
286 | |
class PCA_Analysis(object):
    """
    Placeholder PCA component.

    `analyze` only records the mean of X over axis 0 and sets
    eigvecs/eigvals to the dummy value 1.  After analysis, `filt` therefore
    just subtracts that mean.
    """

    def __init__(self):
        self.clear()

    def clear(self):
        """Forget any previous analysis (all statistics reset to 0)."""
        self.mean = self.eigvecs = self.eigvals = 0

    def analyze(self, X):
        """Record the mean of X over axis 0 (eigen-decomposition not implemented)."""
        self.mean = numpy.mean(X, axis=0)
        self.eigvecs = self.eigvals = 1

    def filt(self, X):
        """Return (X - mean) * eigvecs."""
        centred = X - self.mean
        return centred * self.eigvecs  #TODO: divide by root eigvals?

    def pseudo_inverse(self, Y):
        """Map filtered data back (identity placeholder)."""
        return Y
303 | |
class Layer(object):
    """Placeholder layer: `filt` multiplies its input by the weight `w`."""

    def __init__(self, w):
        self.w = w

    def filt(self, x):
        """Return w * x."""
        weight = self.w
        return weight * x

    def clear(self):
        """Reset the weight to 0."""
        self.w = 0
311 | |
def print_obj(obj):
    """Print `obj` on its own line (debug hook)."""
    print(obj)
def print_obj_attr(obj, attr):
    """Print the attribute named `attr` of `obj` (debug hook)."""
    value = getattr(obj, attr)
    print(value)
def no_op(*args, **kwargs):
    """Accept any arguments and do nothing; returns None."""
    return None
318 | |
class cd1_update(object):
    """
    Dummy CD-1 style update of `layer.w`, usable as a FILT callback.

    Calling the object with a batch `X` adds `lr * mean(X)` to `layer.w`
    and returns None.  This is a placeholder, not the real CD-1 math.
    """
    def __init__(self, layer, lr):
        self.layer = layer
        self.lr = lr

    def __call__(self, X):
        # update self.layer from observation X.
        # numpy.mean instead of X.mean(): it also accepts a plain list of
        # samples (e.g. straight from BUFFER_REPEAT), and for an ndarray it
        # gives the same result as X.mean().
        self.layer.w += numpy.mean(X) * self.lr  #TODO: not exactly correct math
327 | |
def simple_main():
    """
    Tiny demo: weave two BUFFER_REPEATs (3 and 5 samples) that share one
    counter.

    Each call to the source function prints the counter before incrementing
    it.  The demo then prints the value returned by VirtualMachine.run.
    """
    counter = [0]

    def bump(amount):
        print(counter)
        counter[0] += amount
        return counter[0]

    program = WEAVE([
        BUFFER_REPEAT(3, CALL(bump, 1)),
        BUFFER_REPEAT(5, CALL(bump, 1)),
    ])
    result = VirtualMachine(program).run()
    print(result)
340 | |
def main():
    """
    Demo: 10-fold evaluation of a PCA + two-layer pipeline on random data.

    For each fold it resets the components, trains PCA, trains layer1
    (interleaved with printing layer1.w), and trains layer2.  It then
    scores kf.test_size() test examples and stores the scores.  Finally it
    prints kf.scores.
    """
    # create components
    rng = numpy.random.RandomState(123)
    dataset = Dataset(rng.randn(13, 1))
    pca = PCA_Analysis()
    layer1 = Layer(w=4)
    layer2 = Layer(w=3)
    kf = KFold(dataset, K=10)

    # create algorithm
    train_pca = SEQ([
        BUFFER_REPEAT(1000, CALL(kf.next)),
        FILT(pca.analyze),
    ])

    train_layer1 = REPEAT(10, [
        BUFFER_REPEAT(10, CALL(kf.next)),
        FILT(pca.filt),
        FILT(cd1_update(layer1, lr=.01)),
    ])

    train_layer2 = REPEAT(10, [
        BUFFER_REPEAT(10, CALL(kf.next)),
        FILT(pca.filt),
        FILT(layer1.filt),
        FILT(cd1_update(layer2, lr=.01)),
    ])

    # WEAVE alternates a train_layer1 step with a print of layer1.w
    show_layer1_w = LOOP(CALL(print_obj_attr, layer1, 'w'))
    train_prog = SEQ([
        train_pca,
        WEAVE([train_layer1, show_layer1_w]),
        train_layer2,
    ])

    # Score a single test example by pushing it through the whole pipeline.
    # (May want to allow this SEQ to be optimized into a shorter one that
    # chains together theano graphs.)
    score_one_example = SEQ([
        CALL(kf.next_test),
        FILT(pca.filt),
        FILT(layer1.filt),
        FILT(layer2.filt),
        FILT(numpy.mean),
    ])

    kfold_prog = REPEAT(10, [
        CALL(kf.next_fold),
        CALL(pca.clear),
        CALL(layer1.clear),
        CALL(layer2.clear),
        train_prog,
        CALL(kf.init_test),
        BUFFER_REPEAT(kf.test_size(), score_one_example),
        FILT(kf.store_scores),
    ])

    vm = VirtualMachine(kfold_prog)
    # TODO: deep-copy / pickle `vm` here to exercise job migration.
    vm.run(n_steps=200000)
    print(kf.scores)
396 | |
397 | |
if __name__ == '__main__':
    # Run the demo and use main()'s return value as the exit status.
    exit_status = main()
    sys.exit(exit_status)
400 |