annotate doc/v2_planning/plugin_PL.py @ 1332:837768915081

added test idea to test_mcRBM
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 18 Oct 2010 08:53:08 -0400
parents 826d78f0135f
children
rev   line source
1253
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
class RBM(Model):
    '''
    Restricted Boltzmann Machine.

    Holds the symbolic visible layer and the shared parameters (weight
    matrix and the two bias vectors) of one RBM.  `outputs` and
    `gradients` are still placeholders in this prototype, and `cost` is
    None.
    '''
    def __init__(self, n_visible, n_hidden, visible=None, name=None):
        '''
        n_visible: number of visible units (rows of W, length of b_vis).
        n_hidden: number of hidden units (columns of W, length of b_hid).
        visible: symbolic variable to use as the visible layer; when None,
            a new theano matrix named '<name>.visible' is created.
        name: prefix used to name the symbolic variables; defaults to the
            class name.
        '''
        if name is None:
            self.__name__ = self.__class__.__name__
        else:
            self.__name__ = name
        self.n_visible = n_visible
        self.n_hidden = n_hidden

        # Test the argument, not the attribute: self.visible is only
        # assigned in the two branches below.
        if visible is None:
            self.visible = theano.tensor.matrix(
                    name='.'.join([self.__name__, 'visible']))
        else:
            self.visible = visible

        # All parameters start at zero, in theano's configured float type.
        self.W = theano.shared(
                numpy.zeros((n_visible, n_hidden), dtype=theano.config.floatX),
                name=self.__name__ + '.W')
        self.b_hid = theano.shared(
                numpy.zeros((n_hidden,), dtype=theano.config.floatX),
                name=self.__name__ + '.b_hid')
        # One bias per visible unit (the prototype sized this with n_hidden).
        self.b_vis = theano.shared(
                numpy.zeros((n_visible,), dtype=theano.config.floatX),
                name=self.__name__ + '.b_vis')

        self.inputs = [self.visible]
        self.targets = []
        self.parameters = [self.W, self.b_hid, self.b_vis]
        # Placeholders kept from the prototype: the outputs, the cost and
        # the gradients of the RBM have not been written yet.
        self.outputs = ...
        self.cost = None
        self.gradients = [...]
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
35
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
class LogisticRegression(Model):
    '''
    Logistic regression model.

    Empty stub in this prototype: it defines no attributes or methods of
    its own.
    '''
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
38
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
class GradientBasedLearner(Learner):
    '''
    Learner that uses a gradient-based Optimizer to train a Model.

    The optimizer turns the model's parameters and cost into an update
    dictionary, which is compiled together with the cost into one theano
    training function.
    '''
    def __init__(self, model, optimizer, name=None):
        '''
        model: object providing `parameters`, `cost`, `inputs` and
            `targets`.
        optimizer: object providing
            `iterative_optimizer(parameters=..., cost=...)`.
        name: name of this learner; defaults to the class name.
        '''
        self.model = model
        self.optimizer = optimizer
        # The prototype elided part of the set-up here.  At least keep the
        # name, the same way RBM does, instead of dropping the argument.
        if name is None:
            self.__name__ = self.__class__.__name__
        else:
            self.__name__ = name

        self.updates = optimizer.iterative_optimizer(
                parameters=model.parameters,
                cost=model.cost)

        # TODO: not sure of how to interface data set with the function
        self.train_fn = theano.function(model.inputs + model.targets,
                model.cost, updates=self.updates)

    def use_dataset(self, dataset):
        '''
        Select the data that subsequent calls to adapt() will train on.
        '''
        self.train_set = dataset

    # The decorator indicates that this function will declare some hooks.
    # More hooks could be automatically declared, for instance
    # 'begin_function' and 'end_function'
    @declare_hooks(['begin_train_iter', 'end_train_iter'])
    def adapt(self, n_steps=1):
        '''
        Run `n_steps` training iterations on the current training set.

        The 'begin_train_iter' and 'end_train_iter' hooks are executed
        around every iteration.  Both receive the iteration index, the
        total number of steps and the local variables as context.
        '''
        for i in range(n_steps):
            self.adapt.hooks.execute(
                    'begin_train_iter',
                    context=dict(iter=i, total=n_steps, locals=locals()))

            # NOTE(review): the whole item is passed as a single argument,
            # while train_fn is compiled with model.inputs + model.targets
            # as inputs -- confirm the shape of what the data set yields.
            data = self.train_set.next()
            self.train_fn(data)

            self.adapt.hooks.execute(
                    'end_train_iter',
                    context=dict(iter=i, total=n_steps, locals=locals()))
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
75
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
76
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
class SGD(Optimizer):
    '''
    Stochastic gradient descent with fixed learning rate.
    '''
    def __init__(self, step_size):
        '''
        step_size: factor applied to every gradient in the update rule.
        '''
        self.step_size = step_size

    def iterative_optimizer(self,
            parameters,
            cost=None,
            gradients=None,
            stop=None,
            updates=None,
            ):
        '''
        Build the dictionary of updates for one step of gradient descent.

        parameters: variables to update.
        cost: expression differentiated with respect to `parameters` when
            `gradients` is not given.
        gradients: one gradient per parameter, in the same order as
            `parameters`.
        stop: optional variable; when given, it is mapped to False in the
            result ("never stop").
        updates: optional dictionary of already existing updates.  It is
            extended in place and returned.

        Returns a dictionary mapping each parameter `p` to
        `p - step_size * g`.

        Raises ValueError when neither `cost` nor `gradients` is given,
        and KeyError when a parameter already has an update.
        '''
        if updates is not None:
            ret = updates
        else:
            ret = {}

        if gradients is None:
            if cost is None:
                raise ValueError('SGD needs to be provided either a cost or a gradients list')
            gradients = theano.tensor.grad(cost, parameters)

        for p, g in zip(parameters, gradients):
            # Look in `ret`, not `updates`: `updates` may be None, and
            # `ret` also catches a parameter that is listed twice.
            if p in ret:
                raise KeyError('Parameter %s already has an update value (%s)' % (p, ret[p]))
            ret[p] = p - self.step_size * g

        # never stop
        if stop is not None:
            ret[stop] = False

        return ret
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
112
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
113
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
114
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
115 class DBN(Learner):
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
116 '''
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
117 Deep Belief Network.
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
118 '''
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
def __init__(self, n_layers, layer_config, n_ft_steps, ft_step_size):
    '''
    Build the stack of pre-training layers and the fine-tuning function.

    n_layers: not used by this constructor.
    layer_config: one configuration object per layer; only its
        `n_pretrain_steps` attribute is read here.
    n_ft_steps: not used by this constructor.
    ft_step_size: step size of the SGD optimizer used for fine-tuning.
    '''
    # Nothing visible here sets the name used for the symbolic variables
    # below, so fall back to the class name when no name is present.
    if getattr(self, '__name__', None) is None:
        self.__name__ = self.__class__.__name__

    # Layers are GradientBasedLearners, with DBN as Model
    self.layers = []
    # Pretraining cumulative schedule
    self.pt_cumul_schedule = [0]
    # Build the layers and the fully-connected model
    self.input = theano.tensor.matrix(name='.'.join([self.__name__, 'input']))
    self.output = self.input
    self.ft_params = []
    for lconf in layer_config:
        # The remaining RBM arguments are elided in the prototype.  The
        # placeholder is written first so that the call is valid syntax.
        rbm = RBM(..., visible=self.output)
        # NOTE(review): the RBM class in this file does not define
        # `hidden_expectation` -- confirm where it should come from.
        self.output = rbm.hidden_expectation
        layer = GradientBasedLearner(...)
        self.layers.append(layer)

        # Entry k + 1 is the stage at which pre-training of layer k ends.
        self.pt_cumul_schedule.append(
                self.pt_cumul_schedule[-1] + lconf.n_pretrain_steps)

        self.ft_params.extend([rbm.W, rbm.b_hid])

    # Build the fine-tunable model
    self.target = theano.tensor.ivector(name='.'.join([self.__name__, 'target']))
    logreg = LogisticRegression(...)
    self.output = logreg.output
    self.cost = logreg.nll
    self.ft_params.extend([logreg.W, logreg.b])

    ft_optimizer = SGD(ft_step_size)
    self.ft_updates = ft_optimizer.iterative_optimizer(
            parameters=self.ft_params,
            cost=self.cost)
    self.ft_fn = theano.function(
            [self.input, self.target],
            self.cost,
            updates=self.ft_updates,
            name='.'.join([self.__name__, 'ft_fn']))

    # Number of training steps done so far; compared against
    # pt_cumul_schedule by adapt().
    self.stage = 0
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
157
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
158 @declare_hooks([
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
159 'begin_pretrain_layer','end_pretrain_layer',
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
160 'begin_finetune_iter', 'end_finetune_iter'])
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
161
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
162 def adapt(self, n_steps=1):
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
163 '''
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
164 Each "step" is accomplished by the corresponding Learner (either an RBM,
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
165 or the global NNet).
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
166 '''
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
167 train_x, train_y = self.dataset
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
168 n_remaining_steps = n_steps
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
169 # Unsupervised pre-training
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
170 for i, layer in ienumerate(self.layers):
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
171 if (self.pt_cumul_schedule[i] <= self.stage
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
172 and self.stage < self.pt_cumul_schedule[i+1]):
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
173
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
174 self.adapt.hooks.execute(
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
175 'begin_pretrain_layer',
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
176 context = dict(iter=i, total=len(self.layers), locals=locals()))
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
177
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
178 n_pt_steps = min(n_remaining_steps, self.pt_cumul_schedule[i+1] - self.stage)
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
179 layer.use_dataset(train_x)
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
180 layer.adapt(n_steps = n_pt_steps)
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
181 self.stage += self.n_pt_steps
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
182 n_remaining_steps -= n_pt_steps
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
183
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
184 self.adapt.hooks.execute(
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
185 'end_pretrain_layer',
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
186 context = dict(iter=i, total=len(self.layers), locals=locals()))
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
187
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
188 # For the next layer, the data needs to be preprocessed
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
189 train_x = layer.compute_Eh_given_v(train_x) # or just compute_output?
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
190
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
191 # Supervised fine-tuning
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
192 if n_remaining_steps > 0:
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
193 sup_data = train_x, train_y
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
194 for i in xrange(n_remaining_steps):
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
195 self.adapt.hooks.execute(
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
196 'begin_train_iter',
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
197 context = dict(iter=i, total=n_steps, locals=locals()))
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
198
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
199 data = self.train_set.next()
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
200 self.ft_fn(data)
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
201
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
202 self.adapt.hooks.execute(
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
203 'end_train_iter',
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
204 context = dict(iter = i, locals=locals()))
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
205
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
206
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
207 ## TODO: implement k-fold cross-validation
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
208
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
class Hooks(object):
    """Registry of named hooks and the callbacks attached to each of them.

    A hook must first be declared by name; callbacks can then be registered
    on it together with a condition, and `execute` runs every callback whose
    condition accepts the current context.
    """

    def __init__(self):
        # The DB is a dictionary:
        # the keys are the hooks' names (as strings),
        # the values are lists of (function, exec_condition) pairs,
        # kept in registration order.
        self.db = {}

    def declare(self, name):
        """Create the hook `name` with an empty list of callbacks.

        Raises KeyError if a hook with that name was already declared.
        """
        if name in self.db:
            raise KeyError('Hook "%s" is already declared' % name)
        self.db[name] = []

    def execute(self, name, context):
        """Run the callbacks registered on hook `name`.

        `context` is a dictionary; it is expanded into keyword arguments
        both for each condition and for each function. Pairs are visited in
        registration order, and a function is only called when its condition
        returns a true value.

        Raises KeyError if the hook `name` was never declared.
        """
        if name not in self.db:
            raise KeyError('Hook "%s" does not exist' % name)
        #TODO: add contextual information to context, like current time, time of last call,...
        for fn, cond in self.db[name]:
            if cond(**context):
                fn(**context)

    def register(self, name, function, exec_condition):
        """Attach `function` to hook `name`, guarded by `exec_condition`.

        Both callables are invoked by `execute` with the context expanded as
        keyword arguments.

        Raises KeyError if the hook `name` was never declared.
        """
        if name not in self.db:
            raise KeyError('Hook "%s" does not exist' % name)
        self.db[name].append((function, exec_condition))

    #TODO: add __getattr__ to have more intuitive access to the hooks
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
235
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
236 # Hook declaration mechanism
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
def declare_hooks(hooks_list):
    """Build a decorator that attaches a `Hooks` registry to a function.

    The returned decorator stores a fresh `Hooks` instance in the `hooks`
    attribute of the function it receives and declares every name of
    `hooks_list` on it.

    `Hooks.declare` raises KeyError on a repeated name, so `hooks_list`
    must not contain duplicates.
    """
    def deco(f):
        f.hooks = Hooks()
        for hook_name in hooks_list:
            f.hooks.declare(hook_name)
        # Return the function itself: without this, `@declare_hooks([...])`
        # would rebind the decorated name to None.
        return f

    return deco
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
244
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
245 # Conditions
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
def always():
    """Return a condition that accepts any arguments and is always true."""
    def _unconditional(*args, **kwargs):
        return True

    return _unconditional
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
248
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
249
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
def main():
    """Sketch of a DBN experiment that attaches callbacks to training hooks.

    Fetches the MNIST train and test sets, trains a PCA preprocessor on the
    training inputs, builds a 3-layer DBN and registers two callbacks on
    the `adapt` hooks of its first layer. The callbacks themselves are
    still placeholders (`...`).
    """
    train_data = MNIST.gettrain()
    test_data = MNIST.gettest()

    train_x, train_y = train_data

    preprocessor = PCA(ndim=80)
    preprocessor.train(train_x)

    preprocessed_x = preprocessor.compute_output(train_x)
    ## for more robustness, we can have something like:
    # preprocessed_x = ProcessDataSet(orig_data = train_x, function = preprocessor.compute_output)

    x = matrix()
    dbn = DBN(
        input=x,
        n_layers=3,
        # NOTE(review): `[d] * 3` repeats the same dict object three times;
        # confirm DBN does not mutate the per-layer configuration.
        layer_config=[dict(n_hidden=500, n_unsup_steps=1000)] * 3
    )

    # The keyword must be `exec_condition` to match
    # Hooks.register(name, function, exec_condition).
    dbn.layers[0].adapt.hooks.register(
        'begin_train_iter',
        function=...,  # placeholder: callback still to be written
        exec_condition=always()
    )

    # Run the second callback only on every 20th iteration.
    dbn.layers[0].adapt.hooks.register(
        'end_train_iter',
        function=...,  # placeholder: callback still to be written
        exec_condition=lambda iter, **kwargs: iter % 20 == 0
    )


if __name__ == '__main__':
    main()
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
286
826d78f0135f Prototype for "hooks" simpler than full control-flow rewrite.
Pascal Lamblin <lamblinp@iro.umontreal.ca>
parents:
diff changeset
287