view doc/v2_planning/plugin.py @ 1128:03b41a79bd60

coding_style: Replies to James' questions / comments
author Olivier Delalleau <delallea@iro>
date Wed, 15 Sep 2010 12:06:09 -0400
parents 8cc324f388ba
children a1957faecc9b
line wrap: on
line source


import time
from collections import defaultdict

inf = float('inf')

################
### SCHEDULE ###
################

class Schedule(object):
    """Base class for schedules, i.e. callables ``s(t1, t2) -> bool``.

    Operators build composite schedules:

    * ``s + i`` / ``s - i`` wrap ``s`` in an OffsetSchedule with offset
      ``i`` / ``-i``.
    * ``s | x`` and ``x | s`` build a UnionSchedule.
    * ``s & x`` and ``x & s`` build an IntersectionSchedule.
    * ``~s`` builds a NegatedSchedule.

    The other operand of ``|`` and ``&`` is passed through ``to_schedule``
    first, so e.g. ``True``, ``None`` or a list can be combined directly.
    """

    # -- time shifts ---------------------------------------------------
    def __add__(self, i):
        return OffsetSchedule(self, i)

    def __sub__(self, i):
        return OffsetSchedule(self, -i)

    # -- boolean combinators -------------------------------------------
    def __or__(self, s):
        other = to_schedule(s)
        return UnionSchedule(self, other)

    def __ror__(self, s):
        other = to_schedule(s)
        return UnionSchedule(other, self)

    def __and__(self, s):
        other = to_schedule(s)
        return IntersectionSchedule(self, other)

    def __rand__(self, s):
        other = to_schedule(s)
        return IntersectionSchedule(other, self)

    def __invert__(self):
        return NegatedSchedule(self)

def to_schedule(x):
    """Coerce ``x`` into a schedule.

    * ``None`` or ``False`` -> ``never``
    * ``True`` -> ``always``
    * a list or tuple -> the union of its elements, each of which is itself
      coerced; an empty sequence gives a union with no members, which
      never fires
    * anything else is returned unchanged (presumably already a schedule,
      i.e. a callable ``(t1, t2) -> bool``)
    """
    # Identity tests on purpose: ``x in (None, False)`` compares with ``==``
    # and would therefore also map 0 / 0.0 to ``never``.
    if x is None or x is False:
        return never
    if x is True:
        return always
    if isinstance(x, (list, tuple)):
        # One flat union instead of reduce(UnionSchedule, x): it does not
        # crash on an empty sequence, it coerces every element (reduce
        # returned a lone element as-is, e.g. [True] -> True), and it does
        # not rely on the Python-2-only ``reduce`` builtin.
        return UnionSchedule(*x)
    return x


class ScheduleMix(Schedule):
    """Base class for schedules combining other schedules.

    Every sub-schedule is passed through ``to_schedule``. A subclass may set
    ``__n__`` to the exact number of sub-schedules it needs; ``None`` (the
    default) accepts any number.
    """
    __n__ = None

    def __init__(self, *subschedules):
        assert (not self.__n__) or len(subschedules) == self.__n__, \
            "expected %s sub-schedules, got %d" % (self.__n__, len(subschedules))
        # Materialize a real list: on Python 3 ``map`` returns a one-shot
        # iterator, which breaks indexing (``self.subschedules[0]``) and
        # would let a Union/Intersection schedule be evaluated only once.
        self.subschedules = [to_schedule(s) for s in subschedules]

class UnionSchedule(ScheduleMix):
    """Fires on [t1, t2] when at least one sub-schedule fires there."""
    def __call__(self, t1, t2):
        for sub in self.subschedules:
            if sub(t1, t2):
                return True
        return False

class IntersectionSchedule(ScheduleMix):
    """Fires on [t1, t2] only when every sub-schedule fires there."""
    def __call__(self, t1, t2):
        for sub in self.subschedules:
            if not sub(t1, t2):
                return False
        return True

class DifferenceSchedule(ScheduleMix):
    """Fires when the first sub-schedule fires and the second one does not."""
    __n__ = 2

    def __call__(self, t1, t2):
        keep = self.subschedules[0]
        drop = self.subschedules[1]
        return keep(t1, t2) and not drop(t1, t2)

class NegatedSchedule(ScheduleMix):
    """Fires exactly when its single sub-schedule does not."""
    __n__ = 1

    def __call__(self, t1, t2):
        inner = self.subschedules[0]
        return not inner(t1, t2)

class OffsetSchedule(Schedule):
    """Shift a schedule in time by ``offset``.

    The wrapped schedule is queried on [t1 - offset, t2 - offset], so an
    event it reports at time t is reported by this schedule at t + offset.
    """
    def __init__(self, schedule, offset):
        self.schedule, self.offset = schedule, offset

    def __call__(self, t1, t2):
        shift = self.offset
        return self.schedule(t1 - shift, t2 - shift)


class AlwaysSchedule(Schedule):
    """Schedule that fires for every time range."""
    def __call__(self, t1, t2):
        return True

# Shared singletons: ``always`` fires on every range and ``never`` (its
# negation) on none.
always = AlwaysSchedule()
never = NegatedSchedule(always)

class IntervalSchedule(Schedule):
    """Periodic schedule with events at 0, step, 2*step, ...

    :param step: distance between two consecutive events.
    :param repeat: number of events (unlimited by default); the last event
        is at ``step * (repeat - 1)``.

    ``s(t1, t2)`` is true when some event time lies within [t1, t2].
    """
    def __init__(self, step, repeat = inf):
        self.step = step
        self.upper_bound = step * (repeat - 1)

    def __call__(self, t1, t2):
        # Range lies wholly before the first or after the last event.
        if t2 < 0 or t1 > self.upper_bound:
            return False
        span = t2 - t1
        lo_phase = t1 % self.step
        hi_phase = t2 % self.step
        # A range at least one step wide must contain an event.
        if span >= self.step:
            return True
        # Otherwise an endpoint sits on an event, or the phase wrapped
        # around, which means a multiple of step lies in between.
        return lo_phase == 0 or hi_phase == 0 or lo_phase > hi_phase

each0 = IntervalSchedule

def each(step, repeat = inf):
    """Like ``each0`` but shifted by one step: events at step, 2*step, ..."""
    return each0(step, repeat) + step


class RangeSchedule(Schedule):
    """Schedule that fires whenever [t1, t2] overlaps [low, high].

    :param low: inclusive lower bound; ``None`` means -infinity.
    :param high: inclusive upper bound; ``None`` means +infinity.
    """
    def __init__(self, low = None, high = None):
        # Test against None explicitly: ``low or -inf`` silently turned a
        # legitimate bound of 0 into an unbounded one.
        self.low = -inf if low is None else low
        self.high = inf if high is None else high

    def __call__(self, t1, t2):
        # Two closed intervals overlap iff each one starts no later than the
        # other ends. Unlike checking whether t1 or t2 falls in [low, high],
        # this also catches a range [t1, t2] that strictly contains
        # [low, high] (several iterations skipped at once).
        return t1 <= self.high and t2 >= self.low

inrange = RangeSchedule


class ListSchedule(Schedule):
    """Schedule firing at an explicit set of time points.

    ``ListSchedule(a, b, ...)`` fires on [t1, t2] when any given time point
    lies in that (inclusive) range.
    """
    def __init__(self, *schedules):
        # NOTE: despite the attribute name, these are time points, not
        # schedule objects.
        self.schedules = schedules

    def __call__(self, t1, t2):
        return any(t1 <= point <= t2 for point in self.schedules)

at = ListSchedule
# Match the ranges used before the first and after the last iteration of
# the runner, where every timeline is at -inf / +inf respectively.
at_start, at_end = at(-inf), at(inf)


##############
### RUNNER ###
##############

class scratchpad:
    """Empty attribute bag; ``runner`` creates one and passes it as ``this``."""

# # ORIGINAL RUNNER, NO TIMELINES
# def runner(master, plugins):
#     """
#     master is a function which is in charge of the "this" object. It
#         is responsible for updating the t1, t2 and done fields. It must
#         take a single argument, this.

#     plugins is a list of (schedule, function) pairs. In-between each
#         execution of the master function, as well as at the very
#         beginning and at the very end, the schedule will be consulted
#         for the time range [t1, t2], and if there is a match, the
#         function will be called with this as the argument. The order
#         in which the functions are provided is respected.

#     Note: the reason why we use t1 and t2 instead of just t is that it
#     gives the master function the ability to run several iterations at
#     once without consulting any plugins. In that situation, t1 and t2
#     represent a range, and the schedule must determine if there would
#     have been an event in that range (we do not distinguish between a
#     single event and multiple events).

#     For instance, if one is training using minibatches, one could set
#     t1 and t2 to the index of the lower and higher examples, and the
#     plugins' schedules would be given according to how many examples
#     were seen rather than how many minibatches were processed.

#     Another possibility is to use real time - t1 would be the time
#     before the execution of the master function, t2 the time after
#     (in, say, milliseconds). Then you can define plugins that run
#     every second or every minute, but only in-between two training
#     iterations.
#     """

#     this = scratchpad()
#     this.t1 = -inf
#     this.t2 = -inf
#     this.started = False
#     this.done = False
#     while True:
#         for schedule, function in plugins:
#             if schedule(this.t1, this.t2):
#                 function(this)
#                 if this.done:
#                     break
#         master(this)
#         this.started = True
#         if this.done:
#             break
#     this.t1 = inf
#     this.t2 = inf
#     for schedule, function in plugins:
#         if schedule(this.t1, this.t2):
#             function(this)




def runner(main, plugins):
    """
    :param main: A function which must take a single argument,
        ``this``. The ``this`` argument contains a settable ``done``
        flag indicating whether the iterations should keep going or
        not, as well as a flag indicating whether this is the first
        time runner() is calling main(). main() may store whatever it
        wants in ``this``. It may also add one or more timelines in
        ``this.timelines[timeline_name]``, which plugins can exploit.

    :param plugins: A list of (schedule, timeline, function)
        tuples. In-between each execution of the main function, as
        well as at the very beginning and at the very end, the
        schedule will be consulted for the time range [t1, t2] from
        the appropriate timeline, and if there is a match, the
        function will be called with ``this`` as the argument. The
        order in which the functions are provided is respected.

        For any plugin, the timeline can be
        * 'iterations', where t1 == t2 == the iteration number
        * 'real_time', where t1 and t2 mark the start of the last
          loop and the start of the current loop, in seconds since
          the beginning of training (includes time spent in plugins)
        * 'algorithm_time', where t1 and t2 mark the start and end
          of the last iteration of the main function (does not
          include time spent in plugins)
        * A main function specific timeline.

        At the very beginning, the time for all timelines is
        -infinity, at the very end it is +infinity.

    Stopping: if a plugin sets ``this.done``, the remaining plugins of
    that round are skipped and main() is not called again; if main()
    sets it, the loop ends right after that call. In both cases the
    end-of-run plugins (timelines at +infinity) are then executed.
    """
    start_time = time.time()

    this = scratchpad()

    # Unknown timelines (e.g. requested by a plugin before main() creates
    # them) default to the "before the start" range.
    this.timelines = defaultdict(lambda: [-inf, -inf])
    # Keep references to the built-in timelines; they are updated in place
    # with slice assignment so the lists stored in the dict stay current.
    realt = this.timelines['real_time']
    algot = this.timelines['algorithm_time']
    itert = this.timelines['iterations']

    this.started = False
    this.done = False

    while True:

        for schedule, timeline, function in plugins:
            if schedule(*this.timelines[timeline]):
                function(this)
                if this.done:
                    break
        # The break above only leaves the plugin loop; without this check
        # main() would run one extra iteration after a plugin asked to stop.
        if this.done:
            break

        t1 = time.time()
        main(this)
        t2 = time.time()

        if not this.started:
            # Replace the -inf sentinels so the updates below start from 0
            # (and from iteration 0 for the 'iterations' timeline).
            realt[:] = [0, 0]
            algot[:] = [0, 0]
            itert[:] = [-1, -1]
        realt[:] = [realt[1], t2 - start_time]
        algot[:] = [algot[1], algot[1] + (t2 - t1)]
        itert[:] = [itert[0] + 1, itert[1] + 1]

        this.started = True
        if this.done:
            break

    # End of run: every timeline (including main-specific ones) is at +inf.
    this.timelines = defaultdict(lambda: [inf, inf])

    for schedule, timeline, function in plugins:
        if schedule(*this.timelines[timeline]):
            function(this)





################
### SHOWCASE ###
################

def main(this):
    """Toy main(): sets ``this.error`` to 1.0 on the first call, then
    divides it by 1.1 on every later call."""
    if this.started:
        this.error /= 1.1
    else:
        # First call; runner() sets this.started to True after it returns.
        this.error = 1.0


def welcome(this):
    """Plugin: print a greeting (``this`` is ignored)."""
    print("Let's start!")

def print_iter(this):
    """Plugin: print t1 of the 'iterations' timeline."""
    iteration = this.timelines['iterations'][0]
    print("Now running iteration #%i" % iteration)

def print_error(this):
    """Plugin: print the current value of ``this.error``."""
    message = "The error rate is %s" % this.error
    print(message)

def maybe_stop(this, thr=0.01):
    """Plugin: stop the run once ``this.error`` is strictly below ``thr``.

    When the threshold is crossed, prints a message and sets ``this.done``;
    otherwise ``this`` is left untouched.

    :param thr: error threshold (defaults to 0.01, the previously
        hard-coded value).
    """
    if this.error < thr:
        # The comparison is strict, so report it with '<' (the message used
        # to say '<=', which did not match the condition).
        print("Error is below the threshold: %s < %s" % (this.error, thr))
        this.done = True

def wait_a_bit(this):
    """Plugin: sleep for 1/37 of a second (``this`` is ignored)."""
    delay = 1. / 37
    time.sleep(delay)

def printer(txt):
    """Return a plugin function that prints ``txt`` each time it is called."""
    def print_txt(this):
        print(txt)
    return print_txt

def stop_this_madness(this):
    """Plugin: unconditionally ask the runner to stop."""
    setattr(this, 'done', True)

def byebye(this):
    """Plugin: print a farewell message (``this`` is ignored)."""
    print("Bye bye!")

# NOTE(review): this demo runs as soon as the module is executed or
# imported; consider guarding it with ``if __name__ == '__main__':``.
runner(main, [
    # Before the first iteration (all timelines at -inf): greet.
    (at_start, 'iterations', welcome),
    # Print the error on iterations 1..10 inclusive and on every positive
    # multiple of 10 (each() skips 0, each0() includes it).
    (inrange(1, 10) | each(10), 'iterations', print_error),
    # On every positive multiple of 10, check the stopping condition.
    (each(10), 'iterations', maybe_stop),
    # Hard stop at iteration 1000, should the run get that far.
    (at(1000), 'iterations', stop_this_madness),
    # Slow things down a little on every iteration from 1 on.
    (each(1), 'iterations', wait_a_bit),
    # Print BONK once per second of wall-clock time (plugins included).
    (each(1), 'real_time', printer('BONK')),
    # Print THUNK once per second spent inside main(); main() is so fast
    # that this only fires after many iterations.
    (each(1), 'algorithm_time', printer('THUNK')),
    # Announce the iteration number, starting from 0.
    (each0(1), 'iterations', print_iter),
    # After the last iteration (all timelines at +inf): say goodbye.
    (at_end, 'iterations', byebye),
])