Mercurial > pylearn
annotate stopper.py @ 504:19ab9ce916e3
slightly more sophisticated system for finding the mnist data
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 29 Oct 2008 11:38:49 -0400 |
parents | 0ea793361d85 |
children |
rev | line source |
---|---|
178 | 1 """Early stopping iterators |
2 | |
3 The idea here is to supply early-stopping heuristics that can be used in the | |
4 form: | |
5 | |
6 stopper = SomeEarlyStopper() | |
7 | |
8 for i in stopper: | |
9 # train from data | |
10 if i.set_score: | |
11 i.score = validation_score | |
12 | |
13 | |
14 So far I only have one heuristic, so maybe this won't scale. | |
15 """ | |
16 | |
class Stopper(object):
    """Base class providing training loops driven by an early-stopping iterator.

    The methods below iterate over ``self``.  Every object produced by that
    iteration must expose the attributes ``iter``, ``set_score``,
    ``best_score`` and a writable ``score``; this class does not define the
    iteration itself.
    """

    def train(self, data, update_rows_fn, update, validate, save=None):
        """Return the best model trained on data

        Parameters:
        data - a thing that accepts getitem(<list of int64>), or a tuple of such things
        update_rows_fn - fn : int --> <list or tensor of int>
        update - fn: update an internal model from elements of data
        validate - fn: evaluate an internal model (called with no arguments,
                   returns the score; lower is better)
        save - fn: return a copy of the internal model

        The body of this function exhausts the <self> iterator, and trains a
        model using early stopping in the process.

        Returns the value of the last call to save(), which is made only when
        a score beats both the stopper's best_score and the score of the
        previous snapshot.  Returns None if save is not given or no score
        ever qualified.
        """
        best = None
        best_saved_score = None  # score that went with the snapshot in `best`
        for stp in self:
            # call update on some training set rows
            t_rows = update_rows_fn(stp.iter)
            if isinstance(data, (tuple, list)):
                update(*[d[t_rows] for d in data])
            else:
                update(data[t_rows])

            if stp.set_score:
                stp.score = validate()
                if (stp.score < stp.best_score) and save:
                    # The stopper may leave best_score untouched for small
                    # improvements, so also compare with the last snapshot;
                    # otherwise a later, worse score could overwrite it.
                    if best_saved_score is None or stp.score < best_saved_score:
                        best = save()
                        best_saved_score = stp.score
        return best

    def find_min(self, step, check, save):
        """Run step() until the <self> iterator is exhausted, tracking the best score.

        Parameters:
        step - fn: called with no arguments once per iteration
        check - fn: called with no arguments when a score is requested;
                returns the score (lower is better)
        save - fn: return a copy of whatever is being optimized; may be a
               false value, in which case nothing is saved

        Returns a tuple (save(), iter, score) for the lowest score that also
        beat the stopper's best_score, or None if there was no such score.
        """
        best = None
        for stp in self:
            step()
            if stp.set_score:
                stp.score = check()
                if (stp.score < stp.best_score) and save:
                    # best[2] is the score of the current snapshot; only a
                    # strictly lower score may replace it.
                    if best is None or stp.score < best[2]:
                        best = (save(), stp.iter, stp.score)
        return best
40c8a46b3da7
added Stopper.find_min
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
213
diff
changeset
|
59 |
40c8a46b3da7
added Stopper.find_min
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
213
diff
changeset
|
60 |
178 | 61 |
class ICML08Stopper(Stopper):
    """Early-stopping iterator based on an initial wait and a patience factor.

    The object is its own iterator: each step returns ``self`` with ``iter``
    advanced by one.  When ``set_score`` is True after a step, the caller
    must assign a value to ``score`` before asking for the next step.

    Iteration continues while ``iter < initial_wait`` or
    ``iter < patience * best_iter``, and always ends once ``iter`` reaches
    ``hard_limit``.  A score becomes the new best only when it is below
    ``best_score * min_improvement``.
    """

    @staticmethod
    def icml08(ntrain, batchsize):
        """Some setting similar to what I used for ICML08 submission"""
        #TODO: what did I actually use?  put that in here.
        # Floor division gives the same integers that `/` gave for int
        # arguments under Python 2, and avoids float intervals on Python 3.
        return ICML08Stopper(30 * ntrain // batchsize,
                ntrain // batchsize, 0.96, 2.0, 100000000)

    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit):
        """Store the stopping parameters and reset the iteration state.

        Parameters:
        i_wait - keep iterating while iter is below this value
        v_int - a score is requested whenever iter is a multiple of this
        min_improvement - factor applied to best_score; a score must be
                          below the product to count as a new best
        patience - keep iterating while iter < patience * best_iter
        hard_limit - iteration stops when iter reaches this value
        """
        self.initial_wait = i_wait
        self.set_score_interval = v_int
        self.min_improvement = min_improvement
        self.patience = patience
        self.hard_limit = hard_limit

        self.best_score = float('inf')
        self.best_iter = -1
        self.iter = -1

        self.set_score = False
        self.score = None

    def __iter__(self):
        return self

    E_set_score = 'when iter.set_score is True, caller must assign a score to iter.score'

    def next(self):
        """Advance one iteration and return self, or raise StopIteration.

        Raises Exception(E_set_score) if the previous step requested a score
        and the caller did not assign one.
        """
        if self.set_score:  # left over from last time
            if self.score is None:
                raise Exception(ICML08Stopper.E_set_score)
            if self.score < (self.best_score * self.min_improvement):
                (self.best_score, self.best_iter) = (self.score, self.iter)
            self.score = None  # un-set it
            # The request has been served.  Clearing the flag means a call
            # made after exhaustion raises StopIteration again instead of
            # complaining about a missing score.
            self.set_score = False

        starting = self.iter < self.initial_wait
        waiting = self.iter < (self.patience * self.best_iter)
        if starting or waiting:
            # continue to iterate
            self.iter += 1
            # `>=` rather than `==`: a non-integer limit still stops the
            # loop, and an exhausted iterator that is polled again cannot
            # step past the limit and resume.
            if self.iter >= self.hard_limit:
                raise StopIteration
            self.set_score = (self.iter % self.set_score_interval == 0)
            return self

        raise StopIteration

    # Python 3 looks up __next__ for the iterator protocol; Python 2 uses next.
    __next__ = next
111 | |
112 | |
class NStages(ICML08Stopper):
    """Stopper for a fixed number of steps that asks for a validation score
    at regular intervals."""

    def __init__(self, hard_limit, v_int):
        # The initial wait equals the hard limit, and both improvement factors
        # are neutral (1.0), so the hard limit is what ends the iteration.
        ICML08Stopper.__init__(
            self,
            i_wait=hard_limit,
            v_int=v_int,
            min_improvement=1.0,
            patience=1.0,
            hard_limit=hard_limit,
        )

    #TODO: could optimize next() function. Most of what's in ICML08Stopper.next()
    #is not necessary
121 | |
122 |