annotate stopper.py @ 179:9911d2cc3c01

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 13 May 2008 15:14:04 -0400
parents 4090779e39a9
children bd728c83faff
rev   line source
178
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
"""Early stopping iterators

The idea here is to supply early-stopping heuristics that can be used in the
form:

    stopper = SomeEarlyStopper()

    for i in stopper:
        # train from data
        if i.set_score:
            i.score = validation_score

So far I only have one heuristic, so maybe this won't scale.
"""
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
16
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
class Stopper(object):
    """Base class for early-stopping iterators.

    A concrete subclass must be an iterator whose items expose:
      iter       - index of the current training step
      set_score  - True when the caller must assign a validation score
      score      - attribute the caller assigns that validation score to
      best_score - lowest score the stopper has recorded so far
    """

    def train(self, data, update_rows_fn, update, validate, save=None):
        """Return the best model trained on data

        Parameters:
        data - a thing that accepts getitem(<list of int64>), or a tuple/list
            of such things
        update_rows_fn - fn : int --> <list or tensor of int>; maps the step
            index to the rows of `data` used for that step
        update - fn: update an internal model from elements of data. Called
            with one positional argument per element when `data` is a tuple
            or list, otherwise with a single argument.
        validate - fn: evaluate the internal model; called with no arguments
            and must return a score where lower is better
        save - fn: return a copy of the internal model

        Returns the value of the last `save()` call, i.e. the copy of the
        model with the lowest validation score seen. Returns None if `save`
        is not given or no score ever beat the stopper's best_score.

        The body of this function exhausts the <self> iterator, and trains a
        model using early stopping in the process.
        """
        best = None
        # Score of the model currently held in `best`. A stopper may only move
        # its own best_score on a sufficiently large improvement (e.g.
        # ICML08Stopper's min_improvement). Comparing against best_score alone
        # would then let a worse model overwrite a better saved one.
        best_saved_score = float('inf')
        for stp in self:
            # call update on some training set rows
            t_rows = update_rows_fn(stp.iter)
            if isinstance(data, (tuple, list)):
                update(*[d[t_rows] for d in data])
            else:
                update(data[t_rows])

            if stp.set_score:
                stp.score = validate()
                if stp.score < min(stp.best_score, best_saved_score) and save:
                    best = save()
                    best_saved_score = stp.score
        return best
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
49
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
50
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
class ICML08Stopper(Stopper):
    """Patience-based early stopper, usable as an iterator.

    Each step yields the stopper itself. When `set_score` is True, the caller
    must assign a validation score (lower is better) to `score` before
    requesting the next step.

    Another step is taken while the previous step index is below
    `initial_wait`, or below `patience * best_iter`. Iteration always stops
    once the step index reaches `hard_limit`. A score only counts as a new
    best when it is below `best_score * min_improvement`.
    """

    @staticmethod
    def icml08(ntrain, batchsize):
        """Some setting similar to what I used for ICML08 submission

        Uses an initial wait of 30 * ntrain // batchsize steps and a
        validation interval of ntrain // batchsize steps (at least 1).
        Requires min_improvement 0.96, has patience 2.0 and a hard limit of
        100000000 steps. Presumably ntrain is the training-set size, so the
        interval is one epoch -- confirm with callers.
        """
        #TODO: what did I actually use? put that in here.
        # Floor division keeps step counts integral, as `/` on ints did under
        # Python 2. A float interval would make `iter % interval == 0` never
        # hold for non-integer ratios, so validation would never run.
        # The max() avoids a zero interval (modulo by zero in next()) when
        # ntrain < batchsize.
        v_int = max(1, ntrain // batchsize)
        return ICML08Stopper(30 * ntrain // batchsize,
                v_int, 0.96, 2.0, 100000000)

    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit):
        """
        i_wait - keep stepping at least while the step index is below this
        v_int - a score is requested on steps where iter % v_int == 0
        min_improvement - a score becomes the new best only if it is below
            best_score * min_improvement
        patience - keep stepping while the step index is below
            patience * best_iter
        hard_limit - iteration stops when the step index reaches this value
        """
        self.initial_wait = i_wait
        self.set_score_interval = v_int
        self.min_improvement = min_improvement
        self.patience = patience
        self.hard_limit = hard_limit

        self.best_score = float('inf')  # lowest accepted score so far
        self.best_iter = -1  # step index at which best_score was recorded
        self.iter = -1  # index of the current step; -1 before the first

        self.set_score = False  # True when the caller owes a score
        self.score = None  # score assigned by the caller

    def __iter__(self):
        return self

    E_set_score = 'when iter.set_score is True, caller must assign a score to iter.score'

    def next(self):
        """Advance one step and return self, or raise StopIteration.

        Raises Exception(E_set_score) if the previous step requested a score
        and the caller did not assign one.
        """
        if self.set_score:  # left over from last time
            if self.score is None:
                raise Exception(ICML08Stopper.E_set_score)
            if self.score < (self.best_score * self.min_improvement):
                (self.best_score, self.best_iter) = (self.score, self.iter)
            self.score = None  # un-set it

        starting = self.iter < self.initial_wait
        waiting = self.iter < (self.patience * self.best_iter)
        if starting or waiting:
            # continue to iterate
            self.iter += 1
            # `>=` rather than `==`: a call made after exhaustion arrives with
            # iter already at hard_limit and must not step past it.
            if self.iter >= self.hard_limit:
                # The pending score request was consumed above. Clearing it
                # makes later calls stop again instead of raising E_set_score.
                self.set_score = False
                raise StopIteration
            self.set_score = (self.iter % self.set_score_interval == 0)
            return self

        self.set_score = False  # see above: keep StopIteration repeatable
        raise StopIteration

    def __next__(self):
        """Python 3 iterator protocol; delegates to next() (kept for callers)."""
        return self.next()
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
97
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
98
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
class NStages(ICML08Stopper):
    """Run for a fixed number of steps, checking validation set every so
    often.

    Yields steps 0 .. hard_limit - 1. A score is requested on each step whose
    index is a multiple of v_int.
    """

    def __init__(self, hard_limit, v_int):
        # With patience 1.0 the patience test (iter < best_iter) never holds,
        # because best_iter is always an earlier or equal step. Iteration is
        # therefore bounded by the initial wait and the hard limit, and both
        # are set to hard_limit.
        ICML08Stopper.__init__(self,
                i_wait=hard_limit,
                v_int=v_int,
                min_improvement=1.0,
                patience=1.0,
                hard_limit=hard_limit)

    # TODO: next() could be optimized here; most of the work done by
    # ICML08Stopper.next() is unnecessary for a fixed number of stages.
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
107
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
108