178
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
1 """Early stopping iterators
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
2
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
3 The idea here is to supply early-stopping heuristics that can be used in the
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
4 form:
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
5
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
6 stopper = SomeEarlyStopper()
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
7
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
8 for i in stopper:
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
9 # train from data
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
10 if i.set_score:
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
11 i.score = validation_score
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
12
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
13
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
14 So far I only have one heuristic, so maybe this won't scale.
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
15 """
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
16
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
class Stopper(object):
    """Base class for early-stopping iterators.

    ``train`` iterates over ``self``. From each yielded object it reads
    ``iter``, ``set_score`` and ``best_score``, and it assigns ``score``.
    Subclasses must therefore be iterators that provide those attributes.
    """

    def train(self, data, update_rows_fn, update, validate, save=None):
        """Return the best model trained on data

        Parameters:
        data - a thing that accepts getitem(<list of int64>), or a tuple of such things
        update_rows_fn - fn : int --> <list or tensor of int>
        update - fn: update an internal model from elements of data
        validate - fn: evaluate an internal model based on elements of data
        save - fn: return a copy of the internal model

        The body of this function exhausts the <self> iterator, and trains a
        model using early stopping in the process.

        Returns the value of the last save() call made for a score that beat
        both the stopper's best_score and every previously saved score.
        Returns None if save is None or no such score occurred.
        """
        best = None
        # Score of the model currently held in `best`. The stopper's own
        # best_score only moves when an improvement exceeds min_improvement,
        # so it cannot alone guard against replacing a better saved model
        # with a worse one.
        best_saved_score = float('inf')
        for stp in self:
            i = stp.iter

            # call update on some training set rows
            t_rows = update_rows_fn(i)
            if isinstance(data, (tuple, list)):
                update(*[d[t_rows] for d in data])
            else:
                update(data[t_rows])

            if stp.set_score:
                stp.score = validate()
                if save is not None and \
                        stp.score < min(stp.best_score, best_saved_score):
                    best = save()
                    best_saved_score = stp.score
        return best
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
49
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
50
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
class ICML08Stopper(Stopper):
    """Early-stopping iterator with an initial wait, patience and a hard limit.

    Iterating over an instance yields the instance itself once per step.
    Before each step the caller reads ``set_score``. When it is True, the
    caller must assign a validation score (lower is better) to ``score``
    before requesting the next step. Iteration continues while the step
    counter is below ``initial_wait`` or below ``patience * best_iter``.
    The yielded step index never reaches ``hard_limit``.
    """

    @staticmethod
    def icml08(ntrain, batchsize):
        """Some setting similar to what I used for ICML08 submission"""
        #TODO: what did I actually use? put that in here.
        # Floor division reproduces the Python 2 integer '/' the original
        # relied on. The validation interval is clamped to >= 1 so next()
        # never takes a modulo by zero when ntrain < batchsize.
        i_wait = (30 * ntrain) // batchsize
        v_int = max(1, ntrain // batchsize)
        return ICML08Stopper(i_wait, v_int, 0.96, 2.0, 100000000)

    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit):
        """
        Parameters:
        i_wait - number of initial steps during which stopping is never
                 considered
        v_int - a score is requested on every step whose index is a
                multiple of v_int
        min_improvement - a score becomes the new best only if it is below
                          best_score * min_improvement
        patience - keep iterating while the step index is below
                   patience * best_iter
        hard_limit - iteration stops once the step index reaches this value
        """
        self.initial_wait = i_wait
        self.set_score_interval = v_int
        self.min_improvement = min_improvement
        self.patience = patience
        self.hard_limit = hard_limit

        self.best_score = float('inf')
        self.best_iter = -1
        self.iter = -1

        self.set_score = False
        self.score = None

    def __iter__(self):
        return self

    E_set_score = 'when iter.set_score is True, caller must assign a score to iter.score'

    def next(self):
        """Advance one step and return self, or raise StopIteration when done.

        Raises Exception(E_set_score) if the previous step requested a score
        (set_score was True) and the caller did not assign one.
        """
        if self.set_score: #left over from last time
            if self.score is None:
                raise Exception(ICML08Stopper.E_set_score)
            if self.score < (self.best_score * self.min_improvement):
                (self.best_score, self.best_iter) = (self.score, self.iter)
            self.score = None #un-set it

        starting = self.iter < self.initial_wait
        waiting = self.iter < (self.patience * self.best_iter)
        if starting or waiting:
            # continue to iterate
            self.iter += 1
            # '>=' rather than '==': a next() call after exhaustion must keep
            # stopping instead of counting past the limit and resuming.
            if self.iter >= self.hard_limit:
                self.set_score = False
                raise StopIteration
            self.set_score = (self.iter % self.set_score_interval == 0)
            return self

        # Clear the request flag so a later next() keeps raising StopIteration
        # instead of complaining about a missing score.
        self.set_score = False
        raise StopIteration

    # Python 3 iterator protocol name for the same method.
    __next__ = next
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
100
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
101
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
class NStages(ICML08Stopper):
    """Stopper that always runs exactly `hard_limit` steps and requests a
    validation score on every step whose index is a multiple of `v_int`."""

    def __init__(self, hard_limit, v_int):
        # With the initial wait equal to the hard limit, the "starting"
        # condition in ICML08Stopper.next() holds on every step, so only
        # hard_limit ends the iteration. Improvement and patience factors
        # of 1.0 keep the best-score bookkeeping neutral.
        ICML08Stopper.__init__(
            self,
            i_wait=hard_limit,
            v_int=v_int,
            min_improvement=1.0,
            patience=1.0,
            hard_limit=hard_limit,
        )

    #TODO: could optimize next() function. Most of what's in ICML08Stopper.next()
    #is not necessary
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
110
|
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
111
|