diff stopper.py @ 178:4090779e39a9

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 13 May 2008 15:12:20 -0400
parents
children bd728c83faff
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stopper.py	Tue May 13 15:12:20 2008 -0400
@@ -0,0 +1,108 @@
+"""Early stopping iterators
+
+The idea here is to supply early-stopping heuristics that can be used in the
+form:
+
+    stopper = SomeEarlyStopper()
+
+    for i in stopper:
+        # train from data
+        if i.set_score:
+            i.score = validation_score
+
+
+So far I only have one heuristic, so maybe this won't scale.
+"""
+
+class Stopper(object):
+    """Base class providing a generic early-stopping training loop.
+
+    `train` iterates over `self`, so a concrete stopper must be iterable and
+    must yield objects exposing the attributes used below: `iter`,
+    `set_score`, `score` and `best_score`.
+    """
+
+    def train(self, data, update_rows_fn, update, validate, save=None):
+        """Return the best model trained on data
+
+        Parameters:
+        data - a thing that accepts getitem(<list of int64>), or a tuple of such things
+        update_rows_fn - fn : int --> <list or tensor of int>
+        update - fn: update an internal model from elements of data
+        validate - fn: evaluate an internal model based on elements of data
+        save - fn: return a copy of the internal model
+
+        Returns:
+        The value returned by the save() call made at the lowest validation
+        score that triggered a checkpoint, or None when `save` is not given
+        or no checkpoint was ever taken.
+
+        The body of this function exhausts the <self> iterator, and trains a
+        model using early stopping in the process.
+        """
+
+        best = None
+        # Lowest validation score for which save() was actually called.  The
+        # stopper's own best_score may lag behind the true minimum (it only
+        # moves when its improvement threshold is met), so comparing against
+        # best_score alone would let a worse model overwrite a better
+        # checkpoint.
+        best_saved_score = float('inf')
+        for stp in self:
+            i = stp.iter
+
+            # call update on some training set rows
+            t_rows = update_rows_fn(i)
+            if isinstance(data, (tuple, list)):
+                update(*[d[t_rows] for d in data])
+            else:
+                update(data[t_rows])
+
+            if stp.set_score:
+                # the stopper asked for a validation score on this iteration
+                stp.score = validate()
+                score = stp.score
+                if (score < stp.best_score) and save \
+                        and (score < best_saved_score):
+                    best = save()
+                    best_saved_score = score
+        return best
+
+
+class ICML08Stopper(Stopper):
+    """Early stopper combining an initial wait, a patience factor and a hard limit.
+
+    The object is its own iterator and yields itself on every step.  On each
+    step `iter` holds the current iteration number and `set_score` tells the
+    caller whether a validation score must be assigned to `score` before the
+    next step is requested.
+
+    Iteration continues while `iter < initial_wait` or
+    `iter < patience * best_iter`, and stops when `iter` reaches `hard_limit`.
+    """
+    @staticmethod
+    def icml08(ntrain, batchsize):
+        """Some setting similar to what I used for ICML08 submission"""
+        #TODO: what did I actually use? put that in here.
+        # Floor division keeps both counts integral even under true division;
+        # the interval is clamped to 1 so that `iter % interval` in next()
+        # cannot divide by zero when ntrain < batchsize.
+        return ICML08Stopper(30*ntrain//batchsize,
+                max(1, ntrain//batchsize), 0.96, 2.0, 100000000)
+
+    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit):
+        """Set up the stopping schedule.
+
+        Parameters:
+        i_wait - iterate at least while iter < i_wait
+        v_int - a score is requested whenever iter % v_int == 0
+        min_improvement - a score becomes the new best only when it is below
+            best_score * min_improvement
+        patience - keep iterating while iter < patience * best_iter
+        hard_limit - stop when iter reaches this value
+        """
+        self.initial_wait = i_wait
+        self.set_score_interval = v_int
+        self.min_improvement = min_improvement
+        self.patience = patience
+        self.hard_limit = hard_limit
+
+        self.best_score = float('inf')
+        self.best_iter = -1
+        self.iter = -1
+
+        self.set_score = False
+        self.score = None
+
+    def __iter__(self):
+        return self
+
+    E_set_score = 'when iter.set_score is True, caller must assign a score to iter.score'
+    def next(self):
+        """Advance one iteration and return self, or raise StopIteration.
+
+        Raises Exception(E_set_score) when the previous step requested a
+        score and none was assigned.
+        """
+        if self.set_score: #left over from last time
+            if self.score is None:
+                raise Exception(ICML08Stopper.E_set_score)
+            if self.score < (self.best_score * self.min_improvement):
+                (self.best_score, self.best_iter) = (self.score, self.iter)
+            self.score = None #un-set it
+            # The request has been honoured.  Clearing the flag keeps a call
+            # made after exhaustion from tripping the check above instead of
+            # raising StopIteration again.
+            self.set_score = False
+
+
+        starting = self.iter < self.initial_wait
+        waiting = self.iter < (self.patience * self.best_iter)
+        # True once an earlier call has stopped on the hard limit; without
+        # this check a further call would step past the limit and resume.
+        exhausted = self.iter == self.hard_limit
+        if (starting or waiting) and not exhausted:
+            # continue to iterate
+            self.iter += 1
+            if self.iter == self.hard_limit:
+                raise StopIteration
+            self.set_score = (self.iter % self.set_score_interval == 0)
+            return self
+
+        raise StopIteration
+
+    def __next__(self):
+        """Python 3 spelling of the iterator protocol; delegates to next()."""
+        return self.next()
+
+
+class NStages(ICML08Stopper):
+    """Stopper that performs a fixed number of iterations.
+
+    The parent's initial wait and hard limit are both set to `hard_limit`,
+    with min_improvement and patience fixed at 1.0; a validation score is
+    requested every `v_int` iterations.
+    """
+    def __init__(self, hard_limit, v_int):
+        n_steps = hard_limit
+        ICML08Stopper.__init__(
+                self,
+                i_wait=n_steps,
+                v_int=v_int,
+                min_improvement=1.0,
+                patience=1.0,
+                hard_limit=n_steps)
+
+    #TODO: could optimize next() function. Most of what's in ICML08Stopper.next()
+    #is not necessary
+
+