178
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
1 """Early stopping iterators
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
2
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
3 The idea here is to supply early-stopping heuristics that can be used in the
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
4 form:
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
5
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
6 stopper = SomeEarlyStopper()
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
7
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
8 for i in stopper:
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
9 # train from data
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
10 if i.set_score:
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
11 i.score = validation_score
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
12
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
13
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
14 So far I only have one heuristic, so maybe this won't scale.
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
15 """
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
16
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
class Stopper(object):
    """Base class for early-stopping iterators.

    This class only provides the generic training loop in train(). It does
    not define the iteration protocol itself: train() iterates over <self>
    and reads the `iter`, `set_score`, `score` and `best_score` attributes of
    each yielded object, so subclasses must supply those.
    """

    def train(self, data, update_rows_fn, update, validate, save=None):
        """Return the best model trained on data

        Parameters:
        data - a thing that accepts getitem(<list of int64>), or a tuple
               (or list) of such things
        update_rows_fn - fn : int --> <list or tensor of int>; maps the
               current iteration number to the training rows to use
        update - fn: update an internal model from elements of data; it is
               called with one positional argument per element of data
        validate - fn: evaluate an internal model; called with no arguments
               and must return a score where lower is better
        save - fn: return a copy of the internal model (optional)

        The body of this function exhausts the <self> iterator, and trains a
        model using early stopping in the process.

        Returns the value of the last save() call that was made for a new
        best validation score, or None if save was not given or no score
        ever improved on the stopper's best_score.
        """
        best = None
        # The stopper may leave best_score unchanged when an improvement is
        # too small for its own criterion, so comparing against best_score
        # alone would let a later, worse score overwrite a better snapshot.
        # Track the score of the snapshot we actually hold.
        best_saved_score = float('inf')

        for stp in self:
            # call update on some training set rows
            t_rows = update_rows_fn(stp.iter)
            if isinstance(data, (tuple, list)):
                update(*[d[t_rows] for d in data])
            else:
                update(data[t_rows])

            if stp.set_score:
                stp.score = validate()
                improved = (stp.score < stp.best_score
                            and stp.score < best_saved_score)
                if improved and save:
                    best_saved_score = stp.score
                    best = save()
        return best
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
49
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
50
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
class ICML08Stopper(Stopper):
    """Early-stopping iterator driven by periodic validation scores.

    Iterating yields the stopper itself. After each step the caller must
    check `set_score`; when it is True the caller must assign a validation
    score (lower is better) to `score` before asking for the next step.

    Iteration continues while either
      * iter < initial_wait, or
      * iter < patience * best_iter,
    and always ends once iter reaches hard_limit.
    """

    # Message raised when a requested validation score was not supplied.
    E_set_score = 'when iter.set_score is True, caller must assign a score to iter.score'

    @staticmethod
    def icml08(ntrain, batchsize):
        """Some setting similar to what I used for ICML08 submission"""
        #TODO: what did I actually use? put that in here.
        # Floor division keeps these as whole iteration counts whether or not
        # true division is in effect. The interval is clamped to 1 because a
        # zero interval would fail in the modulo test inside next().
        return ICML08Stopper(30 * ntrain // batchsize,
                             max(1, ntrain // batchsize),
                             0.96, 2.0, 100000000)

    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit):
        """Set up the stopping criteria.

        i_wait - iterations to run unconditionally before stopping is allowed
        v_int - a validation score is requested whenever iter % v_int == 0
        min_improvement - a score becomes the new best only when it is below
            best_score * min_improvement
        patience - keep going while iter < patience * best_iter
        hard_limit - iteration number at which iteration always stops
        """
        self.initial_wait = i_wait
        self.set_score_interval = v_int
        self.min_improvement = min_improvement
        self.patience = patience
        self.hard_limit = hard_limit

        self.best_score = float('inf')
        self.best_iter = -1
        self.iter = -1

        self.set_score = False
        self.score = None

    def __iter__(self):
        return self

    def next(self):
        """Advance one iteration and return self, or raise StopIteration.

        Raises Exception (with E_set_score) if the previous step requested a
        score and the caller did not assign one.
        """
        if self.set_score:  # left over from last time
            if self.score is None:
                raise Exception(ICML08Stopper.E_set_score)
            if self.score < (self.best_score * self.min_improvement):
                (self.best_score, self.best_iter) = (self.score, self.iter)
            self.score = None  # un-set it

        starting = self.iter < self.initial_wait
        waiting = self.iter < (self.patience * self.best_iter)
        if starting or waiting:
            # continue to iterate, but never count past the hard limit, so
            # that an exhausted stopper stays exhausted on further calls
            if self.iter < self.hard_limit:
                self.iter += 1
            if self.iter < self.hard_limit:
                self.set_score = (self.iter % self.set_score_interval == 0)
                return self

        # Clear the request flag so that calling next() again raises
        # StopIteration rather than complaining about a missing score.
        self.set_score = False
        raise StopIteration

    # Python 3 spelling of the iterator protocol; next() is kept for Python 2.
    __next__ = next
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
97
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
98
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
class NStages(ICML08Stopper):
    """Stopper that runs for a fixed number of steps, requesting a
    validation score every so often."""

    def __init__(self, hard_limit, v_int):
        """Run for `hard_limit` iterations, asking for a score every `v_int`.

        The initial wait is set equal to the hard limit and both
        min_improvement and patience are passed as 1.0.
        """
        n_steps = hard_limit
        ICML08Stopper.__init__(
            self,
            i_wait=n_steps,
            v_int=v_int,
            min_improvement=1.0,
            patience=1.0,
            hard_limit=n_steps)

    #TODO: could optimize next() function. Most of what's in ICML08Stopper.next()
    #is not necessary
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
107
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
108
|