Mercurial > pylearn
comparison doc/v2_planning/plugin.py @ 1135:a1957faecc9b
revised plugin interface and implementation
author | Olivier Breuleux <breuleuo@iro.umontreal.ca> |
---|---|
date | Thu, 16 Sep 2010 02:58:24 -0400 |
parents | 8cc324f388ba |
children |
comparison
equal
deleted
inserted
replaced
1134:0653a85ff2e8 | 1135:a1957faecc9b |
---|---|
1 | 1 |
2 import time | 2 import time |
3 from collections import defaultdict | 3 from collections import defaultdict, deque |
4 from copy import copy | |
4 | 5 |
5 inf = float('inf') | 6 inf = float('inf') |
7 | |
8 ############# | |
9 ### EVENT ### | |
10 ############# | |
11 | |
12 class Event(object): | |
13 | |
14 def __init__(self, type, **attributes): | |
15 self.type = type | |
16 self.__dict__.update(attributes) | |
17 self.attributes = dict(type = type, **attributes) | |
18 | |
19 def match(self, other): | |
20 if isinstance(other, Matcher): | |
21 return other(self) | |
22 else: | |
23 oattr = other.attributes | |
24 for k, v in self.attributes.iteritems(): | |
25 if k in oattr: | |
26 v2 = oattr[k] | |
27 if isinstance(v2, Matcher): | |
28 if not v2(v): return False | |
29 else: | |
30 if v != v2: return False | |
31 return True | |
32 | |
33 def __str__(self): | |
34 return "Event(%s)" % ", ".join("%s=%s" % (k, v) for k, v in self.attributes.iteritems()) | |
35 | |
36 class Matcher(object): | |
37 | |
38 def __call__(self, object): | |
39 raise NotImplementedError("Implement this!") | |
40 | |
41 class FnMatcher(Matcher): | |
42 | |
43 def __init__(self, function): | |
44 self.function = function | |
45 | |
46 def __call__(self, object): | |
47 return self.function(object) | |
48 | |
49 all_events = FnMatcher(lambda _: True) | |
50 | |
51 | |
6 | 52 |
7 ################ | 53 ################ |
8 ### SCHEDULE ### | 54 ### SCHEDULE ### |
9 ################ | 55 ################ |
10 | 56 |
11 class Schedule(object): | 57 class Schedule(Matcher): |
12 def __add__(self, i): | 58 def __add__(self, i): |
13 return OffsetSchedule(self, i) | 59 return OffsetSchedule(self, i) |
14 def __or__(self, s): | 60 def __or__(self, s): |
15 return UnionSchedule(self, to_schedule(s)) | 61 return UnionSchedule(self, to_schedule(s)) |
16 def __and__(self, s): | 62 def __and__(self, s): |
40 def __init__(self, *subschedules): | 86 def __init__(self, *subschedules): |
41 assert (not self.__n__) or len(subschedules) == self.__n__ | 87 assert (not self.__n__) or len(subschedules) == self.__n__ |
42 self.subschedules = map(to_schedule, subschedules) | 88 self.subschedules = map(to_schedule, subschedules) |
43 | 89 |
44 class UnionSchedule(ScheduleMix): | 90 class UnionSchedule(ScheduleMix): |
45 def __call__(self, t1, t2): | 91 def __call__(self, time): |
46 return any(s(t1, t2) for s in self.subschedules) | 92 return any(s(time) for s in self.subschedules) |
47 | 93 |
48 class IntersectionSchedule(ScheduleMix): | 94 class IntersectionSchedule(ScheduleMix): |
49 def __call__(self, t1, t2): | 95 def __call__(self, time): |
50 return all(s(t1, t2) for s in self.subschedules) | 96 return all(s(time) for s in self.subschedules) |
51 | 97 |
52 class DifferenceSchedule(ScheduleMix): | 98 class DifferenceSchedule(ScheduleMix): |
53 __n__ = 2 | 99 __n__ = 2 |
54 def __call__(self, t1, t2): | 100 def __call__(self, time): |
55 return self.subschedules[0](t1, t2) and not self.subschedules[1](t1, t2) | 101 return self.subschedules[0](time) and not self.subschedules[1](time) |
56 | 102 |
57 class NegatedSchedule(ScheduleMix): | 103 class NegatedSchedule(ScheduleMix): |
58 __n__ = 1 | 104 __n__ = 1 |
59 def __call__(self, t1, t2): | 105 def __call__(self, time): |
60 return not self.subschedules[0](t1, t2) | 106 return not self.subschedules[0](time) |
61 | 107 |
62 class OffsetSchedule(Schedule): | 108 class OffsetSchedule(Schedule): |
63 def __init__(self, schedule, offset): | 109 def __init__(self, schedule, offset): |
64 self.schedule = schedule | 110 self.schedule = schedule |
65 self.offset = offset | 111 self.offset = offset |
66 def __call__(self, t1, t2): | 112 def __call__(self, time): |
67 return self.schedule(t1 - self.offset, t2 - self.offset) | 113 if isinstance(time, int): |
114 return self.schedule(time - self.offset) | |
115 else: | |
116 t1, t2 = time | |
117 return self.schedule((t1 - self.offset, t2 - self.offset)) | |
68 | 118 |
69 | 119 |
70 class AlwaysSchedule(Schedule): | 120 class AlwaysSchedule(Schedule): |
71 def __call__(self, t1, t2): | 121 def __call__(self, time): |
72 return True | 122 return True |
73 | 123 |
74 always = AlwaysSchedule() | 124 always = AlwaysSchedule() |
75 never = ~always | 125 never = ~always |
76 | 126 |
77 class IntervalSchedule(Schedule): | 127 class IntervalSchedule(Schedule): |
78 def __init__(self, step, repeat = inf): | 128 def __init__(self, step, repeat = inf): |
79 self.step = step | 129 self.step = step |
80 self.upper_bound = step * (repeat - 1) | 130 self.upper_bound = step * (repeat - 1) |
81 def __call__(self, t1, t2): | 131 def __call__(self, time): |
82 if t2 < 0 or t1 > self.upper_bound: | 132 if isinstance(time, int): |
83 return False | 133 if time < 0 or time > self.upper_bound: |
84 diff = t2 - t1 | 134 return False |
85 t1m = t1 % self.step | 135 return time % self.step == 0 |
86 t2m = t2 % self.step | 136 else: |
87 return (diff >= self.step | 137 t1, t2 = time |
88 or t1m == 0 | 138 if t2 < 0 or t1 > self.upper_bound: |
89 or t2m == 0 | 139 return False |
90 or t1m > t2m) | 140 diff = t2 - t1 |
141 t1m = t1 % self.step | |
142 t2m = t2 % self.step | |
143 return (diff >= self.step | |
144 or t1m == 0 | |
145 or t2m == 0 | |
146 or t1m > t2m) | |
91 | 147 |
92 each = lambda step, repeat = inf: each0(step, repeat) + step | 148 each = lambda step, repeat = inf: each0(step, repeat) + step |
93 each0 = IntervalSchedule | 149 each0 = IntervalSchedule |
94 | 150 |
95 | 151 |
96 class RangeSchedule(Schedule): | 152 class RangeSchedule(Schedule): |
97 def __init__(self, low = None, high = None): | 153 def __init__(self, low = None, high = None): |
98 self.low = low or -inf | 154 self.low = low or -inf |
99 self.high = high or inf | 155 self.high = high or inf |
100 def __call__(self, t1, t2): | 156 def __call__(self, time): |
101 return self.low <= t1 <= self.high \ | 157 if isinstance(time, int): |
102 or self.low <= t2 <= self.high | 158 return self.low <= time <= self.high |
159 else: | |
160 t1, t2 = time | |
161 return self.low <= t1 <= self.high \ | |
162 or self.low <= t2 <= self.high | |
103 | 163 |
104 inrange = RangeSchedule | 164 inrange = RangeSchedule |
105 | 165 |
106 | 166 |
107 class ListSchedule(Schedule): | 167 class ListSchedule(Schedule): |
108 def __init__(self, *schedules): | 168 def __init__(self, *schedules): |
109 self.schedules = schedules | 169 self.schedules = schedules |
110 def __call__(self, t1, t2): | 170 def __call__(self, time): |
111 for t in self.schedules: | 171 if isinstance(time, int): |
112 if t1 <= t <= t2: | 172 return time in self.schedules |
113 return True | 173 else: |
174 for t in self.schedules: | |
175 if t1 <= t <= t2: | |
176 return True | |
114 return False | 177 return False |
115 | 178 |
116 at = ListSchedule | 179 at = ListSchedule |
117 at_start = at(-inf) | |
118 at_end = at(inf) | |
119 | 180 |
120 | 181 |
121 ############## | 182 ############## |
122 ### RUNNER ### | 183 ### PLUGIN ### |
123 ############## | 184 ############## |
124 | 185 |
125 class scratchpad: | 186 class Plugin(object): |
126 pass | 187 |
127 | 188 def attach(self, scheduler): |
128 # # ORIGINAL RUNNER, NO TIMELINES | 189 c = copy(self) |
129 # def runner(master, plugins): | 190 c.scheduler = scheduler |
130 # """ | 191 return c |
131 # master is a function which is in charge of the "this" object. It | 192 |
132 # is in charge of updating the t1, t2 and done fields, It must | 193 def __call__(self, event): |
133 # take a single argument, this. | 194 raise NotImplementedError("Implement this!") |
134 | 195 |
135 # plugins is a list of (schedule, function) pairs. In-between each | 196 def fire(self, type, **attributes): |
136 # execution of the master function, as well as at the very | 197 event = Event(type, issuer = self, **attributes) |
137 # beginning and at the very end, the schedule will be consulted | 198 self.scheduler.queue(event) |
138 # for the time range [t1, t2], and if there is a match, the | 199 |
139 # function will be called with this as the argument. The order | 200 class FnPlugin(Plugin): |
140 # in which the functions are provided is respected. | 201 |
141 | 202 def __init__(self, function): |
142 # Note: the reason why we use t1 and t2 instead of just t is that it | 203 self.function = function |
143 # gives the master function the ability to run several iterations at | 204 |
144 # once without consulting any plugins. In that situation, t1 and t2 | 205 def __call__(self, event): |
145 # represent a range, and the schedule must determine if there would | 206 return self.function(self, event) |
146 # have been an event in that range (we do not distinguish between a | 207 |
147 # single event and multiple events). | 208 class DispatchPlugin(Plugin): |
148 | 209 |
149 # For instance, if one is training using minibatches, one could set | 210 def __call__(self, event): |
150 # t1 and t2 to the index of the lower and higher examples, and the | 211 getattr(self, "on_" + event.type, self.generic)(event) |
151 # plugins' schedules would be given according to how many examples | 212 |
152 # were seen rather than how many minibatches were processed. | 213 def generic(self, event): |
153 | 214 return |
154 # Another possibility is to use real time - t1 would be the time | 215 |
155 # before the execution of the master function, t2 the time after | 216 |
156 # (in, say, milliseconds). Then you can define plugins that run | 217 ################# |
157 # every second or every minute, but only in-between two training | 218 ### SCHEDULER ### |
158 # iterations. | 219 ################# |
159 # """ | 220 |
160 | 221 class Scheduler(object): |
161 # this = scratchpad() | 222 |
162 # this.t1 = -inf | 223 def __init__(self): |
163 # this.t2 = -inf | 224 self.plugins = [] |
164 # this.started = False | 225 self.categorized = defaultdict(list) |
165 # this.done = False | 226 self.event_queue = deque() |
166 # while True: | 227 |
167 # for schedule, function in plugins: | 228 def __call__(self): |
168 # if schedule(this.t1, this.t2): | 229 i = 0 |
169 # function(this) | 230 evq = self.event_queue |
170 # if this.done: | 231 self.queue(Event("begin", issuer = self)) |
171 # break | 232 while True: |
172 # master(this) | 233 self.queue(Event("tick", issuer = self, time = i)) |
173 # this.started = True | 234 while evq: |
174 # if this.done: | 235 event = evq.popleft() |
175 # break | 236 candidates = self.categorized[event.type] + self.categorized[None] |
176 # this.t1 = inf | 237 for event_template, plugin in candidates: |
177 # this.t2 = inf | 238 if event.match(event_template): |
178 # for schedule, function in plugins: | 239 plugin(event) # note: the plugin might queue more events |
179 # if schedule(this.t1, this.t2): | 240 if event.type == "terminate": |
180 # function(this) | 241 return |
181 | 242 i += 1 |
182 | 243 |
183 | 244 def schedule_plugin(self, event_template, plugin): |
184 | 245 plugin = plugin.attach(self) |
185 def runner(main, plugins): | 246 if isinstance(event_template, Matcher) or isinstance(event_template.type, Matcher): |
186 """ | 247 # These plugins may execute upon any event type |
187 :param main: A function which must take a single argument, | 248 self.categorized[None].append((event_template, plugin)) |
188 ``this``. The ``this`` argument contains a settable ``done`` | 249 else: |
189 flag indicating whether the iterations should keep going or | 250 self.categorized[event_template.type].append((event_template, plugin)) |
190 not, as well as a flag indicating whether this is the first | 251 self.plugins.append((event_template, plugin)) |
191 time runner() is calling main(). main() may store whatever it | 252 |
192 wants in ``this``. It may also add one or more timelines in | 253 def queue(self, event): |
193 ``this.timelines[timeline_name]``, which plugins can exploit. | 254 self.event_queue.append(event) |
194 | 255 |
195 :param plugins: A list of (schedule, timeline, function) | 256 |
196 tuples. In-between each execution of the main function, as | 257 |
197 well as at the very beginning and at the very end, the | 258 |
198 schedule will be consulted for the time range [t1, t2] from | 259 @FnPlugin |
199 the appropriate timeline, and if there is a match, the | 260 def printer(self, event): |
200 function will be called with ``this`` as the argument. The | 261 print event |
201 order in which the functions are provided is respected. | 262 |
202 | 263 @FnPlugin |
203 For any plugin, the timeline can be | 264 def stopper(self, event): |
204 * 'iterations', where t1 == t2 == the iteration number | 265 self.fire("terminate") |
205 * 'real_time', where t1 and t2 mark the start of the last | 266 |
206 loop and the start of the current loop, in seconds since | 267 @FnPlugin |
207 the beginning of training (includes time spent in plugins) | 268 def byebye(self, event): |
208 * 'algorithm_time', where t1 and t2 mark the start and end | 269 print "bye bye!" |
209 of the last iteration of the main function (does not | 270 |
210 include time spent in plugins) | 271 |
211 * A main function specific timeline. | 272 @FnPlugin |
212 | 273 def waiter(self, event): |
213 At the very beginning, the time for all timelines is | 274 time.sleep(0.1) |
214 -infinity, at the very end it is +infinity. | 275 |
215 """ | 276 # @FnPlugin |
216 start_time = time.time() | 277 # def timer(self, event): |
217 | 278 # if not hasattr(self, 'previous'): |
218 this = scratchpad() | 279 # self.beginning = time.time() |
219 | 280 # self.previous = 0 |
220 this.timelines = defaultdict(lambda: [-inf, -inf]) | 281 # now = time.time() - self.beginning |
221 realt = this.timelines['real_time'] | 282 # inow = int(now) |
222 algot = this.timelines['algorithm_time'] | 283 # if inow > self.previous: |
223 itert = this.timelines['iterations'] | 284 # self.fire("second", time = inow) |
224 | 285 # self.previous = now |
225 this.started = False | 286 |
226 this.done = False | 287 class Timer(DispatchPlugin): |
227 | 288 |
228 while True: | 289 def on_begin(self, event): |
229 | 290 self.beginning = time.time() |
230 for schedule, timeline, function in plugins: | 291 self.previous = 0 |
231 if schedule(*this.timelines[timeline]): | 292 |
232 function(this) | 293 def on_tick(self, event): |
233 if this.done: | 294 now = time.time() - self.beginning |
234 break | 295 inow = int(now) |
235 | 296 if inow > self.previous: |
236 t1 = time.time() | 297 self.fire("second", time = inow) |
237 main(this) | 298 self.previous = now |
238 t2 = time.time() | 299 |
239 | 300 |
240 if not this.started: | 301 |
241 realt[:] = [0, 0] | 302 sch = Scheduler() |
242 algot[:] = [0, 0] | 303 |
243 itert[:] = [-1, -1] | 304 |
244 realt[:] = [realt[1], t2 - start_time] | 305 sch.schedule_plugin(all_events, Timer()) |
245 algot[:] = [algot[1], algot[1] + (t2 - t1)] | 306 sch.schedule_plugin(Event("tick"), waiter) # this means: execute the waiter plugin (a delay) on every "tick" event. Is it confusing to use Event(...)? |
246 itert[:] = [itert[0] + 1, itert[1] + 1] | 307 sch.schedule_plugin(Event("second"), printer) |
247 | 308 |
248 this.started = True | 309 # sch.schedule_plugin(all_events, printer) |
249 if this.done: | 310 |
250 break | 311 sch.schedule_plugin(Event("tick", time = at(100)), stopper) |
251 | 312 sch.schedule_plugin(Event("terminate"), byebye) |
252 this.timelines = defaultdict(lambda: [inf, inf]) | 313 |
253 | 314 sch() |
254 for schedule, timeline, function in plugins: | |
255 if schedule(*this.timelines[timeline]): | |
256 function(this) | |
257 | |
258 | |
259 | |
260 | |
261 | |
262 ################ | |
263 ### SHOWCASE ### | |
264 ################ | |
265 | |
266 def main(this): | |
267 if not this.started: | |
268 this.error = 1.0 | |
269 # note: runner will automatically set this.started to true | |
270 else: | |
271 this.error /= 1.1 | |
272 | |
273 | |
274 def welcome(this): | |
275 print "Let's start!" | |
276 | |
277 def print_iter(this): | |
278 print "Now running iteration #%i" % this.timelines['iterations'][0] | |
279 | |
280 def print_error(this): | |
281 print "The error rate is %s" % this.error | |
282 | |
283 def maybe_stop(this): | |
284 thr = 0.01 | |
285 if this.error < thr: | |
286 print "Error is below the threshold: %s <= %s" % (this.error, thr) | |
287 this.done = True | |
288 | |
289 def wait_a_bit(this): | |
290 time.sleep(1./37) | |
291 | |
292 def printer(txt): | |
293 def f(this): | |
294 print txt | |
295 return f | |
296 | |
297 def stop_this_madness(this): | |
298 this.done = True | |
299 | |
300 def byebye(this): | |
301 print "Bye bye!" | |
302 | |
303 runner(main = main, | |
304 plugins = [# At the very beginning, print a welcome message | |
305 (at_start, 'iterations', welcome), | |
306 # Each iteration from 1 to 10 inclusive, OR each multiple of 10 | |
307 # (except 0 - each() excludes 0, each0() includes it) | |
308 # print the error | |
309 (inrange(1, 10) | each(10), 'iterations', print_error), | |
310 # Each multiple of 10, check for stopping condition | |
311 (each(10), 'iterations', maybe_stop), | |
312 # At iteration 1000, if we ever get that far, just stop | |
313 (at(1000), 'iterations', stop_this_madness), | |
314 # Wait a bit | |
315 (each(1), 'iterations', wait_a_bit), | |
316 # Print bonk each second of real time | |
317 (each(1), 'real_time', printer('BONK')), | |
318 # Print thunk each second of time in main() (main() | |
319 # is too fast, so this does not happen for many | |
320 # iterations) | |
321 (each(1), 'algorithm_time', printer('THUNK')), | |
322 # Announce the next iteration | |
323 (each0(1), 'iterations', print_iter), | |
324 # At the very end, display a message | |
325 (at_end, 'iterations', byebye)]) | |
326 | |
327 |