Mercurial > pylearn
comparison stopper.py @ 178:4090779e39a9
merged
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 13 May 2008 15:12:20 -0400 |
parents | |
children | bd728c83faff |
comparison
equal
deleted
inserted
replaced
177:69759976b3ac | 178:4090779e39a9 |
---|---|
"""Early stopping iterators

The idea here is to supply early-stopping heuristics that can be used in the
form:

    stopper = SomeEarlyStopper()

    for i in stopper:
        # train from data
        if i.set_score:
            i.score = validation_score


So far I only have one heuristic, so maybe this won't scale.
"""
16 | |
class Stopper(object):
    """Base class providing a generic early-stopping training loop.

    This class does not implement the iterator protocol itself; train()
    iterates over ``self``, so a subclass must make the instance iterable.
    Each item produced by that iteration must expose the attributes that
    train() uses: ``iter`` and ``set_score`` (read), ``score`` (written),
    and ``best_score`` (read).
    """

    def train(self, data, update_rows_fn, update, validate, save=None):
        """Return the best model trained on data

        Parameters:
        data - a thing that accepts getitem(<list of int64>), or a tuple of such things
        update_rows_fn - fn : int --> <list or tensor of int>
        update - fn: update an internal model from elements of data
        validate - fn: evaluate an internal model based on elements of data
        save - fn: return a copy of the internal model

        The body of this function exhausts the <self> iterator, and trains a
        model using early stopping in the process.

        Returns the value of the last save() call made when a validation
        score was strictly below the stopper's best_score, or None if save
        was not given or no such score was seen.
        """

        best = None
        for stp in self:
            i = stp.iter

            # call update on some training set rows
            t_rows = update_rows_fn(i)
            if isinstance(data, (tuple, list)):
                # several parallel data sources: index each with the same rows
                update(*[d[t_rows] for d in data])
            else:
                update(data[t_rows])

            if stp.set_score:
                # the stopper asked for a validation score on this step
                stp.score = validate()
                # best_score still holds the previous best at this point, so
                # this snapshots the model whenever the new score beats it
                if (stp.score < stp.best_score) and save:
                    best = save()
        return best
49 | |
50 | |
class ICML08Stopper(Stopper):
    """Patience-based early-stopping iterator.

    The object is its own iterator: each step returns ``self`` with
    ``self.iter`` advanced by one.  When ``self.set_score`` is True after a
    step, the caller must assign a validation score to ``self.score`` before
    requesting the next step.  Lower scores are better.
    """

    @staticmethod
    def icml08(ntrain, batchsize):
        """Some setting similar to what I used for ICML08 submission"""
        #TODO: what did I actually use? put that in here.
        # Floor division keeps these integral on Python 3 as well; plain '/'
        # there would yield floats and break the modulo-based schedule.
        # The interval is clamped to 1 so ntrain < batchsize does not lead
        # to a modulo-by-zero in next().
        return ICML08Stopper(30 * ntrain // batchsize,
                max(1, ntrain // batchsize), 0.96, 2.0, 100000000)

    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit):
        """Configure the stopping schedule.

        i_wait - iteration keeps going unconditionally while iter < i_wait
        v_int - set_score is raised on every step where iter % v_int == 0
        min_improvement - a score becomes the new best only if it is below
            best_score * min_improvement
        patience - iteration keeps going while iter < patience * best_iter
        hard_limit - iteration stops once iter reaches this value
        """
        self.initial_wait = i_wait
        self.set_score_interval = v_int
        self.min_improvement = min_improvement
        self.patience = patience
        self.hard_limit = hard_limit

        self.best_score = float('inf')
        self.best_iter = -1
        self.iter = -1

        self.set_score = False
        self.score = None

    def __iter__(self):
        return self

    E_set_score = 'when iter.set_score is True, caller must assign a score to iter.score'
    def next(self):
        """Advance one step and return self, or raise StopIteration.

        Raises Exception(E_set_score) if the previous step requested a score
        (set_score was True) and the caller did not assign one.
        """
        if self.set_score: #left over from last time
            if self.score is None:
                raise Exception(ICML08Stopper.E_set_score)
            if self.score < (self.best_score * self.min_improvement):
                (self.best_score, self.best_iter) = (self.score, self.iter)
            self.score = None #un-set it

        starting = self.iter < self.initial_wait
        waiting = self.iter < (self.patience * self.best_iter)
        if starting or waiting:
            # continue to iterate
            self.iter += 1
            # '<' rather than '!=' so that the iterator cannot resume past
            # the limit if it is advanced again after being exhausted
            if self.iter < self.hard_limit:
                self.set_score = (self.iter % self.set_score_interval == 0)
                return self

        # Clear the request flag so that further calls keep raising
        # StopIteration instead of complaining about a missing score.
        self.set_score = False
        raise StopIteration

    # Python 3 spells the iterator method __next__; keep next() for Python 2.
    __next__ = next
97 | |
98 | |
class NStages(ICML08Stopper):
    """Run for a fixed number of steps, checking validation set every so
    often."""
    def __init__(self, hard_limit, v_int):
        """Configure a fixed-length run.

        hard_limit - passed to ICML08Stopper as both the initial wait and
            the hard limit
        v_int - validation interval, passed through unchanged

        min_improvement and patience are both fixed at 1.0.
        """
        ICML08Stopper.__init__(self, hard_limit, v_int, 1.0, 1.0, hard_limit)

    #TODO: could optimize next() function. Most of what's in ICML08Stopper.next()
    #is not necessary
107 | |
108 |