view stopper.py @ 453:ce6b4fd3ab29

Fixed typo in help
author delallea@valhalla.apstat.com
date Thu, 04 Sep 2008 13:48:47 -0400
parents cc96f93a7810
children 40c8a46b3da7
line wrap: on
line source

"""Early stopping iterators

The idea here is to supply early-stopping heuristics that can be used in the
form:

    stopper = SomeEarlyStopper()

    for i in stopper:
        # train from data
        if i.set_score:
            i.score = validation_score


So far I only have one heuristic, so maybe this won't scale.
"""

class Stopper(object):
    """Base class providing a generic early-stopping training loop.

    ``train`` iterates over ``self``, so a subclass must be iterable and the
    objects it yields must expose the attributes used below: ``iter``,
    ``set_score``, ``score`` and ``best_score``.
    """

    def train(self, data, update_rows_fn, update, validate, save=None):
        """Return the best model trained on data

        Parameters:
        data - a thing that accepts getitem(<list of int64>), or a tuple of such things
        update_rows_fn - fn : int --> <list or tensor of int>
        update - fn: update an internal model from elements of data
        validate - fn: called with no arguments, returns the current
            validation score (lower is better)
        save - fn: called with no arguments, returns a copy of the internal
            model

        Returns the value of the save() call made at the lowest validation
        score seen during this call, or None if save is not given or no score
        ever beat the stopper's best_score.

        The body of this function exhausts the <self> iterator, and trains a
        model using early stopping in the process.
        """

        best = None
        # The stopper may leave best_score unchanged after a small improvement,
        # so comparing against it alone could let a later, worse checkpoint
        # overwrite a better one.  Track the score of the saved model here.
        best_saved_score = float('inf')
        for stp in self:
            i = stp.iter

            # call update on some training set rows
            t_rows = update_rows_fn(i)
            if isinstance(data, (tuple, list)):
                update(*[d[t_rows] for d in data])
            else:
                update(data[t_rows])

            if stp.set_score:
                stp.score = validate()
                improved = (stp.score < stp.best_score
                        and stp.score < best_saved_score)
                if improved and save:
                    best_saved_score = stp.score
                    best = save()
        return best


class ICML08Stopper(Stopper):
    """Early-stopping iterator driven by a validation score (lower is better).

    Iterating yields the stopper itself.  Whenever ``set_score`` is True on
    the yielded object, the caller must assign ``score`` before requesting
    the next iteration.

    Parameters:
    i_wait - keep iterating at least while iter < i_wait
    v_int - request a score whenever iter is a multiple of v_int
    min_improvement - a score becomes the new best only when it is below
        best_score * min_improvement
    patience - keep iterating while iter < patience * best_iter
    hard_limit - stop unconditionally once iter reaches this value
    """

    @staticmethod
    def icml08(ntrain, batchsize):
        """Some setting similar to what I used for ICML08 submission"""
        #TODO: what did I actually use? put that in here.
        # Floor division keeps both counts integral regardless of whether "/"
        # means true division.  The validation interval is clamped to 1 so
        # that ntrain < batchsize does not produce a modulo-by-zero in next().
        return ICML08Stopper((30 * ntrain) // batchsize,
                max(1, ntrain // batchsize), 0.96, 2.0, 100000000)

    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit):
        self.initial_wait = i_wait
        self.set_score_interval = v_int
        self.min_improvement = min_improvement
        self.patience = patience
        self.hard_limit = hard_limit

        self.best_score = float('inf')
        self.best_iter = -1
        self.iter = -1

        self.set_score = False
        self.score = None

    def __iter__(self):
        return self

    E_set_score = 'when iter.set_score is True, caller must assign a score to iter.score'
    def next(self):
        """Advance by one iteration and return self.

        Raises StopIteration when neither the initial wait nor the patience
        window covers the current iteration, or when hard_limit is reached.
        Raises Exception if the previous iteration requested a score and the
        caller did not assign one.
        """

        if self.set_score: #left over from last time
            if self.score is None:
                raise Exception(ICML08Stopper.E_set_score)
            if self.score < (self.best_score * self.min_improvement):
                (self.best_score, self.best_iter) = (self.score, self.iter)
            self.score = None #un-set it
            # The request has been honoured; clearing the flag means a call
            # made after exhaustion raises StopIteration, not E_set_score.
            self.set_score = False

        starting = self.iter < self.initial_wait
        waiting = self.iter < (self.patience * self.best_iter)
        if starting or waiting:
            # continue to iterate
            self.iter += 1
            # ">=" rather than "==": once the limit is hit the iterator stays
            # exhausted, and a non-integral limit is still enforced.
            if self.iter >= self.hard_limit:
                raise StopIteration
            self.set_score = (self.iter % self.set_score_interval == 0)
            return self

        raise StopIteration

    # The iterator protocol looks up __next__ on Python 3; keep next() for
    # Python 2 callers.
    __next__ = next


class NStages(ICML08Stopper):
    """Stopper that runs for a fixed number of steps, asking for a
    validation score at a regular interval."""

    def __init__(self, hard_limit, v_int):
        # hard_limit doubles as the initial wait, and the improvement and
        # patience factors are both set to 1.0.
        ICML08Stopper.__init__(
            self,
            i_wait=hard_limit,
            v_int=v_int,
            min_improvement=1.0,
            patience=1.0,
            hard_limit=hard_limit,
        )

    #TODO: could optimize next() function. Most of what's in ICML08Stopper.next()
    #is not necessary