comparison stopper.py @ 178:4090779e39a9

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 13 May 2008 15:12:20 -0400
parents
children bd728c83faff
comparison
equal deleted inserted replaced
177:69759976b3ac 178:4090779e39a9
1 """Early stopping iterators
2
3 The idea here is to supply early-stopping heuristics that can be used in the
4 form:
5
6 stopper = SomeEarlyStopper()
7
8 for i in stopper:
9 # train from data
10 if i.set_score:
11 i.score = validation_score
12
13
14 So far only one heuristic is implemented (ICML08Stopper; NStages is a fixed-length special case of it), so this design may not scale.
15 """
16
class Stopper(object):
    """Base class for early-stopping iterators.

    Iterating over a subclass instance must produce step objects exposing an
    integer ``iter`` attribute, a boolean ``set_score`` attribute and a
    writable ``score`` attribute (ICML08Stopper yields itself as that object).
    """

    def train(self, data, update_rows_fn, update, validate, save=None):
        """Return the best model trained on data

        Parameters:
          data - an object indexable as data[rows], or a tuple/list of such
                 objects; every element is indexed with the same rows and the
                 results are passed positionally to ``update``
          update_rows_fn - fn: iteration number --> rows (e.g. <list or tensor
                 of int>) used to index ``data`` for that step
          update - fn: update an internal model from elements of data
          validate - fn: () --> score of the internal model; lower is better
          save - fn: () --> a copy of the internal model

        The body of this function exhausts the <self> iterator, and trains a
        model using early stopping in the process.  The copy returned by
        ``save`` at the lowest validation score seen is returned; if ``save``
        is None, None is returned.
        """
        best = None
        best_score = float('inf')
        for stp in self:
            # call update on some training set rows
            t_rows = update_rows_fn(stp.iter)
            if isinstance(data, (tuple, list)):
                update(*[d[t_rows] for d in data])
            else:
                update(data[t_rows])

            if stp.set_score:
                stp.score = validate()
                # Compare against our own running minimum rather than
                # stp.best_score: a stopper may only lower best_score when its
                # own improvement criterion holds (e.g. ICML08Stopper's
                # min_improvement factor), so a later, worse score could still
                # be below the stale best_score and overwrite a better model.
                if stp.score < best_score:
                    best_score = stp.score
                    if save is not None:
                        best = save()
        return best
49
50
class ICML08Stopper(Stopper):
    """Patience-based early stopper; lower scores are better.

    Each ``next()`` returns the stopper itself.  When ``set_score`` is True
    the caller must assign a validation score to ``score`` before asking for
    the next step.  That score is consumed by the following ``next()`` call
    and becomes the new ``best_score`` (at ``best_iter``) if it is below
    ``best_score * min_improvement``.

    Iteration continues while ``iter < initial_wait`` or
    ``iter < patience * best_iter``, and never yields an ``iter`` greater
    than or equal to ``hard_limit``.
    """

    @staticmethod
    def icml08(ntrain, batchsize):
        """Some setting similar to what I used for ICML08 submission.

        Presumably one epoch is ntrain // batchsize updates (TODO confirm):
        wait ~30 epochs, validate once per epoch, require the score to drop
        below 0.96 * best, and keep going until iter reaches 2 * best_iter.
        """
        # TODO: what did I actually use? put that in here.
        # Floor division reproduces the Python 2 integer semantics this was
        # written for; true division would give a fractional validation
        # interval, making ``iter % interval == 0`` true only at iteration 0.
        # The interval is clamped to 1 so ntrain < batchsize does not cause a
        # modulo by zero inside next().
        return ICML08Stopper(30 * ntrain // batchsize,
                             max(1, ntrain // batchsize),
                             0.96, 2.0, 100000000)

    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit):
        """Configure the stopping schedule.

        Parameters:
          i_wait - keep iterating at least while iter < i_wait
          v_int - request a score on every step where iter % v_int == 0
          min_improvement - a score replaces best_score only if it is below
                            best_score * min_improvement
          patience - keep iterating while iter < patience * best_iter
          hard_limit - iteration always stops before iter reaches this value
        """
        self.initial_wait = i_wait
        self.set_score_interval = v_int
        self.min_improvement = min_improvement
        self.patience = patience
        self.hard_limit = hard_limit

        self.best_score = float('inf')
        self.best_iter = -1
        # incremented before each step is returned, so the first step is 0
        self.iter = -1

        self.set_score = False
        self.score = None

    def __iter__(self):
        return self

    E_set_score = 'when iter.set_score is True, caller must assign a score to iter.score'

    def next(self):
        """Advance one step and return self, or raise StopIteration.

        Raises Exception(E_set_score) if the previous step requested a score
        (set_score was True) but none was assigned to ``score``.
        """
        if self.set_score:  # score requested on the previous step
            if self.score is None:
                raise Exception(ICML08Stopper.E_set_score)
            # NOTE(review): the multiplicative threshold assumes non-negative
            # scores; for negative scores a factor < 1 loosens the criterion.
            if self.score < (self.best_score * self.min_improvement):
                (self.best_score, self.best_iter) = (self.score, self.iter)
            self.score = None  # un-set it
            # The score has been consumed.  Leaving the flag set would make any
            # later next() call (e.g. a second pass after StopIteration) raise
            # E_set_score instead of StopIteration.
            self.set_score = False

        starting = self.iter < self.initial_wait
        waiting = self.iter < (self.patience * self.best_iter)
        if starting or waiting:
            # continue to iterate
            self.iter += 1
            # '>=' rather than '==': once the limit has been hit, further
            # next() calls keep incrementing iter and must still stop.
            if self.iter >= self.hard_limit:
                raise StopIteration
            self.set_score = (self.iter % self.set_score_interval == 0)
            return self

        raise StopIteration

    # Python 3 iterator protocol; ``next`` is kept for Python 2 callers.
    __next__ = next
97
98
class NStages(ICML08Stopper):
    """Fixed-length schedule: run for ``hard_limit`` steps, requesting a
    validation score on every step whose index is a multiple of ``v_int``.

    Built as an ICML08Stopper whose initial wait equals its hard limit, so
    the run always lasts until the hard limit is reached.
    """

    def __init__(self, hard_limit, v_int):
        # min_improvement = patience = 1.0: with patience 1.0 the
        # "iter < patience * best_iter" condition can never hold, so the run
        # length is governed solely by initial_wait == hard_limit.
        ICML08Stopper.__init__(
            self,
            i_wait=hard_limit,
            v_int=v_int,
            min_improvement=1.0,
            patience=1.0,
            hard_limit=hard_limit,
        )

    # TODO: next() could be specialised for this case; most of the logic in
    # ICML08Stopper.next() is unnecessary for a fixed-length run.
107
108