Mercurial > pylearn
annotate doc/v2_planning/plugin_PL.py @ 1363:18b2ebec6bca
Reply to a comment of OD
author | Razvan Pascanu <r.pascanu@gmail.com> |
---|---|
date | Fri, 12 Nov 2010 11:11:49 -0500 |
parents | 826d78f0135f |
children |
rev | line source |
---|---|
1253
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
class RBM(Model):
    '''
    Restricted Boltzmann Machine.

    Prototype: this constructor only sets up the symbolic input and the
    shared parameters; ``outputs`` and ``gradients`` are still placeholders.
    '''
    def __init__(self, n_visible, n_hidden, visible = None, name = None):
        '''
        :param n_visible: number of visible units (rows of ``W``, length of
            ``b_vis``).
        :param n_hidden: number of hidden units (columns of ``W``, length of
            ``b_hid``).
        :param visible: optional symbolic matrix to use as the visible input;
            a new ``theano.tensor.matrix`` is created when it is None.
        :param name: prefix used to name the symbolic variables; defaults to
            the class name.
        '''
        if name is None:
            self.__name__ = self.__class__.__name__
        else:
            self.__name__ = name
        self.n_visible = n_visible
        self.n_hidden = n_hidden

        # Test the argument: self.visible does not exist yet at this point.
        if visible is None:
            self.visible = theano.tensor.matrix(
                    name='.'.join([self.__name__, 'visible']))
        else:
            self.visible = visible

        self.W = theano.shared(
                numpy.zeros((n_visible, n_hidden), dtype=theano.config.floatX),
                name=self.__name__ + '.W')
        self.b_hid = theano.shared(
                numpy.zeros((n_hidden,), dtype=theano.config.floatX),
                name=self.__name__ + '.b_hid')
        # The visible bias has one entry per visible unit.
        self.b_vis = theano.shared(
                numpy.zeros((n_visible,), dtype=theano.config.floatX),
                name=self.__name__ + '.b_vis')

        self.inputs = [self.visible]
        self.targets = []
        self.parameters = [self.W, self.b_hid, self.b_vis]
        self.outputs = ...      # TODO: placeholder in this prototype
        self.cost = None
        self.gradients = [...]  # TODO: placeholder in this prototype
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
35 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
class LogisticRegression(Model):
    '''
    Logistic regression model (stub, not implemented in this prototype).

    The DBN prototype reads ``output``, ``nll``, ``W`` and ``b`` from
    instances of this class.
    '''
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
38 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
class GradientBasedLearner(Learner):
    '''
    Learner that uses a gradient-based Optimizer to train a Model.
    '''
    def __init__(self, model, optimizer, name = None):
        '''
        :param model: Model providing ``parameters``, ``cost``, ``inputs``
            and ``targets``.
        :param optimizer: object whose ``iterative_optimizer(parameters=...,
            cost=...)`` returns the updates for one training step.
        :param name: optional name; defaults to the class name (same
            convention as ``RBM``).
        '''
        if name is None:
            self.__name__ = self.__class__.__name__
        else:
            self.__name__ = name
        self.model = model
        self.optimizer = optimizer
        # ... (further setup elided in this prototype)

        self.updates = optimizer.iterative_optimizer(
                parameters = model.parameters,
                cost = model.cost)

        # TODO: not sure of how to interface data set with the function
        self.train_fn = theano.function(model.inputs+model.targets,
                model.cost, updates=self.updates)

    def use_dataset(self, dataset):
        '''Set the data source that ``adapt`` pulls minibatches from.'''
        self.train_set = dataset

    # The decorator indicates that this function will declare some hooks.
    # More hooks could be automatically declared, for instance
    # 'begin_function' and 'end_function'
    @declare_hooks(['begin_train_iter', 'end_train_iter'])
    def adapt(self, n_steps=1):
        '''
        Run ``n_steps`` training iterations.

        Each iteration fires 'begin_train_iter', calls ``train_fn`` on the
        next item of the training set, then fires 'end_train_iter'.
        '''
        for i in xrange(n_steps):
            self.adapt.hooks.execute(
                    'begin_train_iter',
                    context = dict(iter=i, total=n_steps, locals=locals()))

            data = self.train_set.next()
            self.train_fn(data)

            # Same context keys as 'begin_train_iter', so a hook written for
            # one can be attached to the other.
            self.adapt.hooks.execute(
                    'end_train_iter',
                    context = dict(iter=i, total=n_steps, locals=locals()))
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
75 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
76 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
class SGD(Optimizer):
    '''
    Stochastic gradient descent with fixed learning rate.
    '''
    def __init__(self, step_size):
        '''
        :param step_size: fixed learning rate multiplying every gradient.
        '''
        self.step_size = step_size

    def iterative_optimizer(self,
            parameters,
            cost=None,
            gradients=None,
            stop=None,
            updates=None,
            ):
        '''
        Build the update dictionary for one SGD step: p <- p - step_size * g.

        :param parameters: parameters to update.
        :param cost: symbolic cost; used to compute ``gradients`` with
            ``theano.tensor.grad`` when those are not given.
        :param gradients: optional gradients, one per parameter, in the same
            order as ``parameters``.
        :param stop: optional variable whose update is set to False
            (SGD never asks to stop).
        :param updates: optional existing update dictionary. It is copied,
            not modified, and the SGD updates are added to the copy.
        :returns: the update dictionary.
        :raises ValueError: if neither ``cost`` nor ``gradients`` is given, or
            if ``gradients`` and ``parameters`` have different lengths.
        :raises KeyError: if a parameter already has an entry in ``updates``.
        '''
        # Copy so the caller's dictionary (and its type, e.g. an ordered
        # dict) is preserved and left untouched.
        if updates is not None:
            ret = updates.copy()
        else:
            ret = {}

        parameters = list(parameters)
        if gradients is None:
            if cost is None:
                raise ValueError('SGD needs to be provided either a cost or a gradients list')
            gradients = theano.tensor.grad(cost, parameters)
        gradients = list(gradients)
        # zip would silently drop unmatched parameters or gradients.
        if len(gradients) != len(parameters):
            raise ValueError('SGD got %d gradients for %d parameters'
                             % (len(gradients), len(parameters)))

        for p, g in zip(parameters, gradients):
            # Check `ret`, not `updates`, which may be None.
            if p in ret:
                raise KeyError('Parameter %s already has an update value (%s)' % (p, ret[p]))
            ret[p] = p - self.step_size * g

        # never stop
        if stop is not None:
            ret[stop] = False

        return ret
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
112 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
113 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
114 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
115 class DBN(Learner): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
116 ''' |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
117 Deep Belief Network. |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
118 ''' |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
119 def __init__(self, n_layers, layer_config, n_ft_steps, ft_step_size): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
120 # Layers are GradientBasedLearners, with DBN as Model |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
121 self.layers = [] |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
122 # Pretraining cumulative schedule |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
123 self.pt_cumul_schedule = [0] |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
124 # Build the layers and the fully-connected model |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
125 self.input = theano.tensor.matrix(name='.'.join([self.__name__, 'input'])) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
126 self.output = self.input |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
127 self.ft_params = [] |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
128 for i,lconf in enumerate(layer_config): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
129 rbm = RBM(visible = self.output, ...) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
130 self.output = rbm.hidden_expectation |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
131 layer = GradientBasedLearner(...) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
132 self.layers.append(layer) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
133 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
134 self.pt_cumul_schedule.append( |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
135 self.pt_cumul_schedule[-1]+lc.n_pretrain_steps) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
136 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
137 self.ft_params.extend([rbm.W, rbm.b_hid]) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
138 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
139 # Build the fine-tunable model |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
140 self.target = theano.tensor.ivector(name='.'.join([self.__name__, 'target'])) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
141 logreg = LogisticRegression(...) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
142 self.output = logreg.output |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
143 self.cost = logreg.nll |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
144 self.ft_params.extend([logreg.W, logreg.b]) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
145 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
146 ft_optimizer = SGD(ft_step_size) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
147 self.ft_updates = ft_optimizer.iterative_optimizer( |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
148 parameters = self.ft_params, |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
149 cost = self.cost) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
150 self.ft_fn = theano.function( |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
151 [self.input, self.target], |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
152 self.cost, |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
153 updates = self.ft_updates, |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
154 name='.'.join([self.__name__, 'ft_fn'])) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
155 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
156 self.stage = 0 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
157 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
158 @declare_hooks([ |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
159 'begin_pretrain_layer','end_pretrain_layer', |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
160 'begin_finetune_iter', 'end_finetune_iter']) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
161 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
162 def adapt(self, n_steps=1): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
163 ''' |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
164 Each "step" is accomplished by the corresponding Learner (either an RBM, |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
165 or the global NNet). |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
166 ''' |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
167 train_x, train_y = self.dataset |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
168 n_remaining_steps = n_steps |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
169 # Unsupervised pre-training |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
170 for i, layer in ienumerate(self.layers): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
171 if (self.pt_cumul_schedule[i] <= self.stage |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
172 and self.stage < self.pt_cumul_schedule[i+1]): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
173 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
174 self.adapt.hooks.execute( |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
175 'begin_pretrain_layer', |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
176 context = dict(iter=i, total=len(self.layers), locals=locals())) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
177 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
178 n_pt_steps = min(n_remaining_steps, self.pt_cumul_schedule[i+1] - self.stage) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
179 layer.use_dataset(train_x) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
180 layer.adapt(n_steps = n_pt_steps) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
181 self.stage += self.n_pt_steps |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
182 n_remaining_steps -= n_pt_steps |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
183 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
184 self.adapt.hooks.execute( |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
185 'end_pretrain_layer', |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
186 context = dict(iter=i, total=len(self.layers), locals=locals())) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
187 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
188 # For the next layer, the data needs to be preprocessed |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
189 train_x = layer.compute_Eh_given_v(train_x) # or just compute_output? |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
190 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
191 # Supervised fine-tuning |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
192 if n_remaining_steps > 0: |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
193 sup_data = train_x, train_y |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
194 for i in xrange(n_remaining_steps): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
195 self.adapt.hooks.execute( |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
196 'begin_train_iter', |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
197 context = dict(iter=i, total=n_steps, locals=locals())) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
198 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
199 data = self.train_set.next() |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
200 self.ft_fn(data) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
201 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
202 self.adapt.hooks.execute( |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
203 'end_train_iter', |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
204 context = dict(iter = i, locals=locals())) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
205 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
206 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
207 ## TODO: implement k-fold cross-validation |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
208 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
209 class Hooks: |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
210 def __init__(self): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
211 # The DB consists in a dictionary, |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
212 # the keys are the hooks' names (as strings), |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
213 # the values are lists of (function, exec_condition) pairs of functions |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
214 self.db = {} |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
215 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
216 def declare(self, name): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
217 if name in self.db: |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
218 raise KeyError('Hook "%s" is already declared' % name) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
219 self.db[name] = [] |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
220 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
221 def execute(self, name, context): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
222 if name not in self.db: |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
223 raise KeyError('Hook "%s" does not exist', % name) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
224 #TODO: add contextual information to context, like current time, time of last call,... |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
225 for fn, cond in self.db[name]: |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
226 if cond(**context): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
227 fn(**context) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
228 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
229 def register(self, name, function, exec_condition): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
230 if name not in self.db: |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
231 raise KeyError('Hook "%s" does not exist', % name) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
232 self.db[name].append((function, exec_condition)) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
233 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
234 #TODO: add __getattr__ to have more intuitive access to the hooks |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
235 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
236 # Hook declaration mechanism |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
237 def declare_hooks(hooks_list): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
238 def deco(f): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
239 f.hooks = Hooks() |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
240 for hook_name in hooks_list: |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
241 f.hooks.declare(hook_name) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
242 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
243 return deco |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
244 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
245 # Conditions |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
246 def always(): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
247 return lambda *args, **kwargs: True |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
248 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
249 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
250 def main(): |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
251 train_data = MNIST.gettrain() |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
252 test_data = MNIST.gettest() |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
253 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
254 train_x, train_y = train_data |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
255 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
256 preprocessor = PCA(ndim = 80) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
257 preprocessor.train(train_x) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
258 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
259 preprocessed_x = preprocessor.compute_output(train_x) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
260 ## for more robustess, we can have something like: |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
261 # peprocessed_x = ProcessDataSet(orig_data = train_x, function = preprocessor.compute_output) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
262 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
263 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
264 x = matrix() |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
265 dbn = DBN( |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
266 input = x, |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
267 n_layers = 3, |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
268 layer_config = [dict(n_hidden = 500, n_unsup_steps=1000)] * 3 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
269 ) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
270 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
271 dbn.layers[0].adapt.hooks.register( |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
272 'begin_train_iter', |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
273 function = ..., |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
274 exec_cond = always() |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
275 ) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
276 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
277 dbn.layers[0].adapt.hooks.register( |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
278 'end_train_iter', |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
279 function = ..., |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
280 exec_cond = lambda iter, **kwargs: iter%20==0 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
281 ) |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
282 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
283 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
284 if __name__ == '__main__': |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
285 main() |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
286 |
826d78f0135f
Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff
changeset
|
287 |