Mercurial > ift6266
annotate code_tutoriel/deep.py @ 338:fca22114bb23
added async save, restart from old model and independant error calculation based on Arnaud's iterator
author | xaviermuller |
---|---|
date | Sat, 17 Apr 2010 12:42:48 -0400 |
parents | 4bc5eeec6394 |
children |
rev | line source |
---|---|
165
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
1 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
2 Draft of DBN, DAA, SDAA, RBM tutorial code |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
3 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
4 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
5 import sys |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
6 import numpy |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
7 import theano |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
8 import time |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
9 import theano.tensor as T |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
10 from theano.tensor.shared_randomstreams import RandomStreams |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
11 from theano import shared, function |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
12 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
13 import gzip |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
14 import cPickle |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
15 import pylearn.io.image_tiling |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
16 import PIL |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
17 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
18 # NNET STUFF |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
19 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
20 class LogisticRegression(object): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
21 """Multi-class Logistic Regression Class |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
22 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
23 The logistic regression is fully described by a weight matrix :math:`W` |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
24 and bias vector :math:`b`. Classification is done by projecting data |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
25 points onto a set of hyperplanes, the distance to which is used to |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
26 determine a class membership probability. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
27 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
28 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
29 def __init__(self, input, n_in, n_out): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
30 """ Initialize the parameters of the logistic regression |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
31 :param input: symbolic variable that describes the input of the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
32 architecture (one minibatch) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
33 :type n_in: int |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
34 :param n_in: number of input units, the dimension of the space in |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
35 which the datapoints lie |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
36 :type n_out: int |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
37 :param n_out: number of output units, the dimension of the space in |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
38 which the labels lie |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
39 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
40 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
41 # initialize with 0 the weights W as a matrix of shape (n_in, n_out) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
42 self.W = theano.shared( value=numpy.zeros((n_in,n_out), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
43 dtype = theano.config.floatX) ) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
44 # initialize the baises b as a vector of n_out 0s |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
45 self.b = theano.shared( value=numpy.zeros((n_out,), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
46 dtype = theano.config.floatX) ) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
47 # compute vector of class-membership probabilities in symbolic form |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
48 self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W)+self.b) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
49 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
50 # compute prediction as class whose probability is maximal in |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
51 # symbolic form |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
52 self.y_pred=T.argmax(self.p_y_given_x, axis=1) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
53 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
54 # list of parameters for this layer |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
55 self.params = [self.W, self.b] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
56 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
57 def negative_log_likelihood(self, y): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
58 """Return the mean of the negative log-likelihood of the prediction |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
59 of this model under a given target distribution. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
60 :param y: corresponds to a vector that gives for each example the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
61 correct label |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
62 Note: we use the mean instead of the sum so that |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
63 the learning rate is less dependent on the batch size |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
64 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
65 return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]),y]) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
66 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
67 def errors(self, y): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
68 """Return a float representing the number of errors in the minibatch |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
69 over the total number of examples of the minibatch ; zero one |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
70 loss over the size of the minibatch |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
71 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
72 # check if y has same dimension of y_pred |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
73 if y.ndim != self.y_pred.ndim: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
74 raise TypeError('y should have the same shape as self.y_pred', |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
75 ('y', target.type, 'y_pred', self.y_pred.type)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
76 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
77 # check if y is of the correct datatype |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
78 if y.dtype.startswith('int'): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
79 # the T.neq operator returns a vector of 0s and 1s, where 1 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
80 # represents a mistake in prediction |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
81 return T.mean(T.neq(self.y_pred, y)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
82 else: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
83 raise NotImplementedError() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
84 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
85 class SigmoidalLayer(object): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
86 def __init__(self, rng, input, n_in, n_out): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
87 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
88 Typical hidden layer of a MLP: units are fully-connected and have |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
89 sigmoidal activation function. Weight matrix W is of shape (n_in,n_out) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
90 and the bias vector b is of shape (n_out,). |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
91 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
92 Hidden unit activation is given by: sigmoid(dot(input,W) + b) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
93 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
94 :type rng: numpy.random.RandomState |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
95 :param rng: a random number generator used to initialize weights |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
96 :type input: theano.tensor.matrix |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
97 :param input: a symbolic tensor of shape (n_examples, n_in) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
98 :type n_in: int |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
99 :param n_in: dimensionality of input |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
100 :type n_out: int |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
101 :param n_out: number of hidden units |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
102 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
103 self.input = input |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
104 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
105 W_values = numpy.asarray( rng.uniform( \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
106 low = -numpy.sqrt(6./(n_in+n_out)), \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
107 high = numpy.sqrt(6./(n_in+n_out)), \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
108 size = (n_in, n_out)), dtype = theano.config.floatX) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
109 self.W = theano.shared(value = W_values) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
110 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
111 b_values = numpy.zeros((n_out,), dtype= theano.config.floatX) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
112 self.b = theano.shared(value= b_values) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
113 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
114 self.output = T.nnet.sigmoid(T.dot(input, self.W) + self.b) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
115 self.params = [self.W, self.b] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
116 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
117 # PRETRAINING LAYERS |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
118 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
119 class RBM(object): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
120 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
121 *** WRITE THE ENERGY FUNCTION USE SAME LETTERS AS VARIABLE NAMES IN CODE |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
122 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
123 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
124 def __init__(self, input=None, n_visible=None, n_hidden=None, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
125 W=None, hbias=None, vbias=None, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
126 numpy_rng=None, theano_rng=None): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
127 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
128 RBM constructor. Defines the parameters of the model along with |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
129 basic operations for inferring hidden from visible (and vice-versa), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
130 as well as for performing CD updates. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
131 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
132 :param input: None for standalone RBMs or symbolic variable if RBM is |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
133 part of a larger graph. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
134 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
135 :param n_visible: number of visible units (necessary when W or vbias is None) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
136 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
137 :param n_hidden: number of hidden units (necessary when W or hbias is None) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
138 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
139 :param W: weights to use for the RBM. None means that a shared variable will be |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
140 created with a randomly chosen matrix of size (n_visible, n_hidden). |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
141 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
142 :param hbias: *** |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
143 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
144 :param vbias: *** |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
145 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
146 :param numpy_rng: random number generator (necessary when W is None) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
147 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
148 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
149 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
150 params = [] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
151 if W is None: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
152 # choose initial values for weight matrix of RBM |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
153 initial_W = numpy.asarray( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
154 numpy_rng.uniform( \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
155 low=-numpy.sqrt(6./(n_hidden+n_visible)), \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
156 high=numpy.sqrt(6./(n_hidden+n_visible)), \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
157 size=(n_visible, n_hidden)), \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
158 dtype=theano.config.floatX) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
159 W = theano.shared(value=initial_W, name='W') |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
160 params.append(W) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
161 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
162 if hbias is None: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
163 # theano shared variables for hidden biases |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
164 hbias = theano.shared(value=numpy.zeros(n_hidden, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
165 dtype=theano.config.floatX), name='hbias') |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
166 params.append(hbias) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
167 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
168 if vbias is None: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
169 # theano shared variables for visible biases |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
170 vbias = theano.shared(value=numpy.zeros(n_visible, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
171 dtype=theano.config.floatX), name='vbias') |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
172 params.append(vbias) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
173 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
174 if input is None: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
175 # initialize input layer for standalone RBM or layer0 of DBN |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
176 input = T.matrix('input') |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
177 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
178 # setup theano random number generator |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
179 if theano_rng is None: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
180 theano_rng = RandomStreams(numpy_rng.randint(2**30)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
181 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
182 self.visible = self.input = input |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
183 self.W = W |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
184 self.hbias = hbias |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
185 self.vbias = vbias |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
186 self.theano_rng = theano_rng |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
187 self.params = params |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
188 self.hidden_mean = T.nnet.sigmoid(T.dot(input, W)+hbias) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
189 self.hidden_sample = theano_rng.binomial(self.hidden_mean.shape, 1, self.hidden_mean) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
190 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
191 def gibbs_k(self, v_sample, k): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
192 ''' This function implements k steps of Gibbs sampling ''' |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
193 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
194 # We compute the visible after k steps of Gibbs by iterating |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
195 # over ``gibs_1`` for k times; this can be done in Theano using |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
196 # the `scan op`. For a more comprehensive description of scan see |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
197 # http://deeplearning.net/software/theano/library/scan.html . |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
198 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
199 def gibbs_1(v0_sample, W, hbias, vbias): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
200 ''' This function implements one Gibbs step ''' |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
201 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
202 # compute the activation of the hidden units given a sample of the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
203 # vissibles |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
204 h0_mean = T.nnet.sigmoid(T.dot(v0_sample, W) + hbias) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
205 # get a sample of the hiddens given their activation |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
206 h0_sample = self.theano_rng.binomial(h0_mean.shape, 1, h0_mean) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
207 # compute the activation of the visible given the hidden sample |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
208 v1_mean = T.nnet.sigmoid(T.dot(h0_sample, W.T) + vbias) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
209 # get a sample of the visible given their activation |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
210 v1_act = self.theano_rng.binomial(v1_mean.shape, 1, v1_mean) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
211 return [v1_mean, v1_act] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
212 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
213 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
214 # DEBUGGING TO DO ALL WITHOUT SCAN |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
215 if k == 1: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
216 return gibbs_1(v_sample, self.W, self.hbias, self.vbias) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
217 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
218 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
219 # Because we require as output two values, namely the mean field |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
220 # approximation of the visible and the sample obtained after k steps, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
221 # scan needs to know the shape of those two outputs. Scan takes |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
222 # this information from the variables containing the initial state |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
223 # of the outputs. Since we do not need a initial state of ``v_mean`` |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
224 # we provide a dummy one used only to get the correct shape |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
225 v_mean = T.zeros_like(v_sample) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
226 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
227 # ``outputs_taps`` is an argument of scan which describes at each |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
228 # time step what past values of the outputs the function applied |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
229 # recursively needs. This is given in the form of a dictionary, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
230 # where the keys are outputs indexes, and values are a list of |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
231 # of the offsets used by the corresponding outputs |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
232 # In our case the function ``gibbs_1`` applied recursively, requires |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
233 # at time k the past value k-1 for the first output (index 0) and |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
234 # no past value of the second output |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
235 outputs_taps = { 0 : [-1], 1 : [] } |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
236 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
237 v_means, v_samples = theano.scan( fn = gibbs_1, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
238 sequences = [], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
239 initial_states = [v_sample, v_mean], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
240 non_sequences = [self.W, self.hbias, self.vbias], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
241 outputs_taps = outputs_taps, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
242 n_steps = k) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
243 return v_means[-1], v_samples[-1] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
244 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
245 def free_energy(self, v_sample): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
246 wx_b = T.dot(v_sample, self.W) + self.hbias |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
247 vbias_term = T.sum(T.dot(v_sample, self.vbias)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
248 hidden_term = T.sum(T.log(1+T.exp(wx_b))) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
249 return -hidden_term - vbias_term |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
250 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
251 def cd(self, visible = None, persistent = None, steps = 1): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
252 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
253 Return a 5-tuple of values related to contrastive divergence: (cost, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
254 end-state of negative-phase chain, gradient on weights, gradient on |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
255 hidden bias, gradient on visible bias) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
256 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
257 If visible is None, it defaults to self.input |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
258 If persistent is None, it defaults to self.input |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
259 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
260 CD aka CD1 - cd() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
261 CD-10 - cd(steps=10) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
262 PCD - cd(persistent=shared(numpy.asarray(initializer))) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
263 PCD-k - cd(persistent=shared(numpy.asarray(initializer)), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
264 steps=10) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
265 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
266 if visible is None: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
267 visible = self.input |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
268 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
269 if visible is None: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
270 raise TypeError('visible argument is required when self.input is None') |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
271 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
272 if steps is None: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
273 steps = self.gibbs_1 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
274 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
275 if persistent is None: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
276 chain_start = visible |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
277 else: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
278 chain_start = persistent |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
279 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
280 chain_end_mean, chain_end_sample = self.gibbs_k(chain_start, steps) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
281 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
282 #print >> sys.stderr, "WARNING: DEBUGGING with wrong FREE ENERGY" |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
283 #free_energy_delta = - self.free_energy(chain_end_sample) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
284 free_energy_delta = self.free_energy(visible) - self.free_energy(chain_end_sample) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
285 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
286 # we will return all of these regardless of what is in self.params |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
287 all_params = [self.W, self.hbias, self.vbias] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
288 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
289 gparams = T.grad(free_energy_delta, all_params, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
290 consider_constant = [chain_end_sample]) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
291 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
292 cross_entropy = T.mean(T.sum( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
293 visible*T.log(chain_end_mean) + (1 - visible)*T.log(1-chain_end_mean), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
294 axis = 1)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
295 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
296 return (cross_entropy, chain_end_sample,) + tuple(gparams) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
297 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
298 def cd_updates(self, lr, visible = None, persistent = None, steps = 1): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
299 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
300 Return the learning updates for the RBM parameters that are shared variables. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
301 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
302 Also returns an update for the persistent if it is a shared variable. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
303 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
304 These updates are returned as a dictionary. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
305 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
306 :param lr: [scalar] learning rate for contrastive divergence learning |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
307 :param visible: see `cd_grad` |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
308 :param persistent: see `cd_grad` |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
309 :param steps: see `cd_grad` |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
310 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
311 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
312 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
313 cross_entropy, chain_end, gW, ghbias, gvbias = self.cd(visible, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
314 persistent, steps) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
315 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
316 updates = {} |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
317 if hasattr(self.W, 'value'): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
318 updates[self.W] = self.W - lr * gW |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
319 if hasattr(self.hbias, 'value'): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
320 updates[self.hbias] = self.hbias - lr * ghbias |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
321 if hasattr(self.vbias, 'value'): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
322 updates[self.vbias] = self.vbias - lr * gvbias |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
323 if persistent: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
324 #if persistent is a shared var, then it means we should use |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
325 updates[persistent] = chain_end |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
326 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
327 return updates |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
328 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
329 # DEEP MODELS |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
330 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
331 class DBN(object): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
332 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
333 *** WHAT IS A DBN? |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
334 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
335 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
336 def __init__(self, input_len, hidden_layers_sizes, n_classes, rng): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
337 """ This class is made to support a variable number of layers. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
338 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
339 :param train_set_x: symbolic variable pointing to the training dataset |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
340 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
341 :param train_set_y: symbolic variable pointing to the labels of the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
342 training dataset |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
343 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
344 :param input_len: dimension of the input to the sdA |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
345 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
346 :param n_layers_sizes: intermidiate layers size, must contain |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
347 at least one value |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
348 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
349 :param n_classes: dimension of the output of the network |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
350 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
351 :param corruption_levels: amount of corruption to use for each |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
352 layer |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
353 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
354 :param rng: numpy random number generator used to draw initial weights |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
355 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
356 :param pretrain_lr: learning rate used during pre-trainnig stage |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
357 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
358 :param finetune_lr: learning rate used during finetune stage |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
359 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
360 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
361 self.sigmoid_layers = [] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
362 self.rbm_layers = [] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
363 self.pretrain_functions = [] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
364 self.params = [] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
365 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
366 theano_rng = RandomStreams(rng.randint(2**30)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
367 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
368 # allocate symbolic variables for the data |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
369 index = T.lscalar() # index to a [mini]batch |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
370 self.x = T.matrix('x') # the data is presented as rasterized images |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
371 self.y = T.ivector('y') # the labels are presented as 1D vector of |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
372 # [int] labels |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
373 input = self.x |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
374 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
375 # The SdA is an MLP, for which all weights of intermidiate layers |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
376 # are shared with a different denoising autoencoders |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
377 # We will first construct the SdA as a deep multilayer perceptron, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
378 # and when constructing each sigmoidal layer we also construct a |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
379 # denoising autoencoder that shares weights with that layer, and |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
380 # compile a training function for that denoising autoencoder |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
381 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
382 for n_hid in hidden_layers_sizes: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
383 # construct the sigmoidal layer |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
384 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
385 sigmoid_layer = SigmoidalLayer(rng, input, input_len, n_hid) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
386 self.sigmoid_layers.append(sigmoid_layer) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
387 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
388 self.rbm_layers.append(RBM(input=input, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
389 W=sigmoid_layer.W, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
390 hbias=sigmoid_layer.b, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
391 n_visible = input_len, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
392 n_hidden = n_hid, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
393 numpy_rng=rng, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
394 theano_rng=theano_rng)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
395 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
396 # its arguably a philosophical question... |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
397 # but we are going to only declare that the parameters of the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
398 # sigmoid_layers are parameters of the StackedDAA |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
399 # the hidden-layer biases in the daa_layers are parameters of those |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
400 # daa_layers, but not the StackedDAA |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
401 self.params.extend(self.sigmoid_layers[-1].params) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
402 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
403 # get ready for the next loop iteration |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
404 input_len = n_hid |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
405 input = self.sigmoid_layers[-1].output |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
406 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
407 # We now need to add a logistic layer on top of the MLP |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
408 self.logistic_regressor = LogisticRegression(input = input, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
409 n_in = input_len, n_out = n_classes) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
410 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
411 self.params.extend(self.logistic_regressor.params) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
412 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
413 def pretraining_functions(self, train_set_x, batch_size, learning_rate, k=1): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
414 if k!=1: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
415 raise NotImplementedError() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
416 index = T.lscalar() # index to a [mini]batch |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
417 n_train_batches = train_set_x.value.shape[0] / batch_size |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
418 batch_begin = (index % n_train_batches) * batch_size |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
419 batch_end = batch_begin+batch_size |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
420 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
421 print 'TRAIN_SET X', train_set_x.value.shape |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
422 rval = [] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
423 for rbm in self.rbm_layers: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
424 # N.B. these cd() samples are independent from the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
425 # samples used for learning |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
426 outputs = list(rbm.cd())[0:2] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
427 rval.append(function([index], outputs, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
428 updates = rbm.cd_updates(lr=learning_rate), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
429 givens = {self.x: train_set_x[batch_begin:batch_end]})) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
430 if rbm is self.rbm_layers[0]: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
431 f = rval[-1] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
432 AA=len(outputs) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
433 for i, implicit_out in enumerate(f.maker.env.outputs): #[len(outputs):]: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
434 print 'OUTPUT ', i |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
435 theano.printing.debugprint(implicit_out, file=sys.stdout) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
436 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
437 return rval |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
438 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
439 def finetune(self, datasets, lr, batch_size): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
440 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
441 # unpack the various datasets |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
442 (train_set_x, train_set_y) = datasets[0] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
443 (valid_set_x, valid_set_y) = datasets[1] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
444 (test_set_x, test_set_y) = datasets[2] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
445 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
446 # compute number of minibatches for training, validation and testing |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
447 assert train_set_x.value.shape[0] % batch_size == 0 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
448 assert valid_set_x.value.shape[0] % batch_size == 0 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
449 assert test_set_x.value.shape[0] % batch_size == 0 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
450 n_train_batches = train_set_x.value.shape[0] / batch_size |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
451 n_valid_batches = valid_set_x.value.shape[0] / batch_size |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
452 n_test_batches = test_set_x.value.shape[0] / batch_size |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
453 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
454 index = T.lscalar() # index to a [mini]batch |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
455 target = self.y |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
456 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
457 train_index = index % n_train_batches |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
458 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
459 classifier = self.logistic_regressor |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
460 cost = classifier.negative_log_likelihood(target) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
461 # compute the gradients with respect to the model parameters |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
462 gparams = T.grad(cost, self.params) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
463 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
464 # compute list of fine-tuning updates |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
465 updates = [(param, param - gparam*finetune_lr) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
466 for param,gparam in zip(self.params, gparams)] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
467 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
468 train_fn = theano.function([index], cost, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
469 updates = updates, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
470 givens = { |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
471 self.x : train_set_x[train_index*batch_size:(train_index+1)*batch_size], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
472 target : train_set_y[train_index*batch_size:(train_index+1)*batch_size]}) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
473 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
474 test_score_i = theano.function([index], classifier.errors(target), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
475 givens = { |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
476 self.x: test_set_x[index*batch_size:(index+1)*batch_size], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
477 target: test_set_y[index*batch_size:(index+1)*batch_size]}) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
478 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
479 valid_score_i = theano.function([index], classifier.errors(target), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
480 givens = { |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
481 self.x: valid_set_x[index*batch_size:(index+1)*batch_size], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
482 target: valid_set_y[index*batch_size:(index+1)*batch_size]}) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
483 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
484 def test_scores(): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
485 return [test_score_i(i) for i in xrange(n_test_batches)] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
486 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
487 def valid_scores(): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
488 return [valid_score_i(i) for i in xrange(n_valid_batches)] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
489 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
490 return train_fn, valid_scores, test_scores |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
491 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
492 def load_mnist(filename): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
493 f = gzip.open(filename,'rb') |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
494 train_set, valid_set, test_set = cPickle.load(f) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
495 f.close() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
496 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
497 def shared_dataset(data_xy): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
498 data_x, data_y = data_xy |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
499 shared_x = theano.shared(numpy.asarray(data_x, dtype=theano.config.floatX)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
500 shared_y = theano.shared(numpy.asarray(data_y, dtype=theano.config.floatX)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
501 return shared_x, T.cast(shared_y, 'int32') |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
502 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
503 n_train_examples = train_set[0].shape[0] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
504 datasets = shared_dataset(train_set), shared_dataset(valid_set), shared_dataset(test_set) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
505 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
506 return n_train_examples, datasets |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
507 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
508 def dbn_main(finetune_lr = 0.01, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
509 pretraining_epochs = 10, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
510 pretrain_lr = 0.1, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
511 training_epochs = 1000, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
512 batch_size = 20, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
513 mnist_file='mnist.pkl.gz'): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
514 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
515 Demonstrate stochastic gradient descent optimization for a multilayer perceptron |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
516 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
517 This is demonstrated on MNIST. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
518 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
519 :param learning_rate: learning rate used in the finetune stage |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
520 (factor for the stochastic gradient) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
521 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
522 :param pretraining_epochs: number of epoch to do pretraining |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
523 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
524 :param pretrain_lr: learning rate to be used during pre-training |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
525 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
526 :param n_iter: maximal number of iterations ot run the optimizer |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
527 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
528 :param mnist_file: path the the pickled mnist_file |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
529 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
530 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
531 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
532 n_train_examples, train_valid_test = load_mnist(mnist_file) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
533 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
534 print "Creating a Deep Belief Network" |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
535 deep_model = DBN( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
536 input_len=28*28, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
537 hidden_layers_sizes = [500, 150, 100], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
538 n_classes=10, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
539 rng = numpy.random.RandomState()) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
540 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
541 #### |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
542 #### Phase 1: Pre-training |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
543 #### |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
544 print "Pretraining (unsupervised learning) ..." |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
545 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
546 pretrain_functions = deep_model.pretraining_functions( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
547 batch_size=batch_size, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
548 train_set_x=train_valid_test[0][0], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
549 learning_rate=pretrain_lr, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
550 ) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
551 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
552 start_time = time.clock() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
553 for layer_idx, pretrain_fn in enumerate(pretrain_functions): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
554 # go through pretraining epochs |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
555 print 'Pre-training layer %i'% layer_idx |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
556 for i in xrange(pretraining_epochs * n_train_examples / batch_size): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
557 outstuff = pretrain_fn(i) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
558 xe, negsample = outstuff[:2] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
559 print (layer_idx, i, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
560 n_train_examples / batch_size, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
561 float(xe), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
562 'Wmin', deep_model.rbm_layers[0].W.value.min(), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
563 'Wmax', deep_model.rbm_layers[0].W.value.max(), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
564 'vmin', deep_model.rbm_layers[0].vbias.value.min(), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
565 'vmax', deep_model.rbm_layers[0].vbias.value.max(), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
566 #'x>0.3', (input_i>0.3).sum(), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
567 ) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
568 sys.stdout.flush() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
569 if i % 1000 == 0: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
570 PIL.Image.fromarray( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
571 pylearn.io.image_tiling.tile_raster_images(negsample, (28,28), (10,10), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
572 tile_spacing=(1,1))).save('samples_%i_%i.png'%(layer_idx,i)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
573 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
574 PIL.Image.fromarray( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
575 pylearn.io.image_tiling.tile_raster_images( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
576 deep_model.rbm_layers[0].W.value.T, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
577 (28,28), (10,10), |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
578 tile_spacing=(1,1))).save('filters_%i_%i.png'%(layer_idx,i)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
579 end_time = time.clock() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
580 print 'Pretraining took %f minutes' %((end_time - start_time)/60.) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
581 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
582 return |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
583 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
584 print "Fine tuning (supervised learning) ..." |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
585 train_fn, valid_scores, test_scores =\ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
586 deep_model.finetune_functions(train_valid_test[0][0], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
587 learning_rate=finetune_lr, # the learning rate |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
588 batch_size = batch_size) # number of examples to use at once |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
589 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
590 #### |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
591 #### Phase 2: Fine Tuning |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
592 #### |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
593 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
594 patience = 10000 # look as this many examples regardless |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
595 patience_increase = 2. # wait this much longer when a new best is |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
596 # found |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
597 improvement_threshold = 0.995 # a relative improvement of this much is |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
598 # considered significant |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
599 validation_frequency = min(n_train_examples, patience/2) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
600 # go through this many |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
601 # minibatche before checking the network |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
602 # on the validation set; in this case we |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
603 # check every epoch |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
604 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
605 patience_max = n_train_examples * training_epochs |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
606 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
607 best_epoch = None |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
608 best_epoch_test_score = None |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
609 best_epoch_valid_score = float('inf') |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
610 start_time = time.clock() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
611 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
612 for i in xrange(patience_max): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
613 if i >= patience: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
614 break |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
615 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
616 cost_i = train_fn(i) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
617 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
618 if i % validation_frequency == 0: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
619 validation_i = numpy.mean([score for score in valid_scores()]) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
620 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
621 # if we got the best validation score until now |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
622 if validation_i < best_epoch_valid_score: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
623 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
624 # improve patience if loss improvement is good enough |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
625 threshold_i = best_epoch_valid_score * improvement_threshold |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
626 if validation_i < threshold_i: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
627 patience = max(patience, i * patience_increase) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
628 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
629 # save best validation score and iteration number |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
630 best_epoch_valid_score = validation_i |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
631 best_epoch = i/validation_i |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
632 best_epoch_test_score = numpy.mean( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
633 [score for score in test_scores()]) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
634 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
635 print('epoch %i, validation error %f %%, test error %f %%'%( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
636 i/validation_frequency, validation_i*100., |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
637 best_epoch_test_score*100.)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
638 else: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
639 print('epoch %i, validation error %f %%' % ( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
640 i/validation_frequency, validation_i*100.)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
641 end_time = time.clock() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
642 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
643 print(('Optimization complete with best validation score of %f %%,' |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
644 'with test performance %f %%') % |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
645 (finetune_status['best_validation_loss']*100., |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
646 finetune_status['test_score']*100.)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
647 print ('The code ran for %f minutes' % ((finetune_status['duration'])/60.)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
648 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
649 def rbm_main(): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
650 rbm = RBM(n_visible=20, n_hidden=30, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
651 numpy_rng = numpy.random.RandomState(34)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
652 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
653 cd_updates = rbm.cd_updates(lr=0.25) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
654 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
655 print cd_updates |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
656 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
657 f = function([rbm.input], [], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
658 updates={rbm.W:cd_updates[rbm.W]}) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
659 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
660 theano.printing.debugprint(f.maker.env.outputs[0], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
661 file=sys.stdout) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
662 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
663 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
664 if __name__ == '__main__': |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
665 dbn_main() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
666 #rbm_main() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
667 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
668 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
669 if 0: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
670 class DAA(object): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
671 def __init__(self, n_visible= 784, n_hidden= 500, corruption_level = 0.1,\ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
672 input = None, shared_W = None, shared_b = None): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
673 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
674 Initialize the dA class by specifying the number of visible units (the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
675 dimension d of the input ), the number of hidden units ( the dimension |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
676 d' of the latent or hidden space ) and the corruption level. The |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
677 constructor also receives symbolic variables for the input, weights and |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
678 bias. Such a symbolic variables are useful when, for example the input is |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
679 the result of some computations, or when weights are shared between the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
680 dA and an MLP layer. When dealing with SdAs this always happens, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
681 the dA on layer 2 gets as input the output of the dA on layer 1, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
682 and the weights of the dA are used in the second stage of training |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
683 to construct an MLP. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
684 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
685 :param n_visible: number of visible units |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
686 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
687 :param n_hidden: number of hidden units |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
688 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
689 :param input: a symbolic description of the input or None |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
690 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
691 :param corruption_level: the corruption mechanism picks up randomly this |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
692 fraction of entries of the input and turns them to 0 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
693 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
694 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
695 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
696 self.n_visible = n_visible |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
697 self.n_hidden = n_hidden |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
698 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
699 # create a Theano random generator that gives symbolic random values |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
700 theano_rng = RandomStreams() |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
701 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
702 if shared_W != None and shared_b != None : |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
703 self.W = shared_W |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
704 self.b = shared_b |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
705 else: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
706 # initial values for weights and biases |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
707 # note : W' was written as `W_prime` and b' as `b_prime` |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
708 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
709 # W is initialized with `initial_W` which is uniformely sampled |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
710 # from -6./sqrt(n_visible+n_hidden) and 6./sqrt(n_hidden+n_visible) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
711 # the output of uniform if converted using asarray to dtype |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
712 # theano.config.floatX so that the code is runable on GPU |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
713 initial_W = numpy.asarray( numpy.random.uniform( \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
714 low = -numpy.sqrt(6./(n_hidden+n_visible)), \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
715 high = numpy.sqrt(6./(n_hidden+n_visible)), \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
716 size = (n_visible, n_hidden)), dtype = theano.config.floatX) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
717 initial_b = numpy.zeros(n_hidden, dtype = theano.config.floatX) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
718 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
719 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
720 # theano shared variables for weights and biases |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
721 self.W = theano.shared(value = initial_W, name = "W") |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
722 self.b = theano.shared(value = initial_b, name = "b") |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
723 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
724 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
725 initial_b_prime= numpy.zeros(n_visible) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
726 # tied weights, therefore W_prime is W transpose |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
727 self.W_prime = self.W.T |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
728 self.b_prime = theano.shared(value = initial_b_prime, name = "b'") |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
729 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
730 # if no input is given, generate a variable representing the input |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
731 if input == None : |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
732 # we use a matrix because we expect a minibatch of several examples, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
733 # each example being a row |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
734 self.x = T.matrix(name = 'input') |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
735 else: |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
736 self.x = input |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
737 # Equation (1) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
738 # keep 90% of the inputs the same and zero-out randomly selected subset of 10% of the inputs |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
739 # note : first argument of theano.rng.binomial is the shape(size) of |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
740 # random numbers that it should produce |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
741 # second argument is the number of trials |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
742 # third argument is the probability of success of any trial |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
743 # |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
744 # this will produce an array of 0s and 1s where 1 has a |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
745 # probability of 1 - ``corruption_level`` and 0 with |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
746 # ``corruption_level`` |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
747 self.tilde_x = theano_rng.binomial( self.x.shape, 1, 1 - corruption_level) * self.x |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
748 # Equation (2) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
749 # note : y is stored as an attribute of the class so that it can be |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
750 # used later when stacking dAs. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
751 self.y = T.nnet.sigmoid(T.dot(self.tilde_x, self.W ) + self.b) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
752 # Equation (3) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
753 self.z = T.nnet.sigmoid(T.dot(self.y, self.W_prime) + self.b_prime) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
754 # Equation (4) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
755 # note : we sum over the size of a datapoint; if we are using minibatches, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
756 # L will be a vector, with one entry per example in minibatch |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
757 self.L = - T.sum( self.x*T.log(self.z) + (1-self.x)*T.log(1-self.z), axis=1 ) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
758 # note : L is now a vector, where each element is the cross-entropy cost |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
759 # of the reconstruction of the corresponding example of the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
760 # minibatch. We need to compute the average of all these to get |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
761 # the cost of the minibatch |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
762 self.cost = T.mean(self.L) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
763 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
764 self.params = [ self.W, self.b, self.b_prime ] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
765 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
766 class StackedDAA(DeepLayerwiseModel): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
767 """Stacked denoising auto-encoder class (SdA) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
768 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
769 A stacked denoising autoencoder model is obtained by stacking several |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
770 dAs. The hidden layer of the dA at layer `i` becomes the input of |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
771 the dA at layer `i+1`. The first layer dA gets as input the input of |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
772 the SdA, and the hidden layer of the last dA represents the output. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
773 Note that after pretraining, the SdA is dealt with as a normal MLP, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
774 the dAs are only used to initialize the weights. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
775 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
776 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
777 def __init__(self, n_ins, hidden_layers_sizes, n_outs, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
778 corruption_levels, rng, ): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
779 """ This class is made to support a variable number of layers. |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
780 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
781 :param train_set_x: symbolic variable pointing to the training dataset |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
782 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
783 :param train_set_y: symbolic variable pointing to the labels of the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
784 training dataset |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
785 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
786 :param n_ins: dimension of the input to the sdA |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
787 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
788 :param n_layers_sizes: intermidiate layers size, must contain |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
789 at least one value |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
790 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
791 :param n_outs: dimension of the output of the network |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
792 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
793 :param corruption_levels: amount of corruption to use for each |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
794 layer |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
795 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
796 :param rng: numpy random number generator used to draw initial weights |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
797 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
798 :param pretrain_lr: learning rate used during pre-trainnig stage |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
799 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
800 :param finetune_lr: learning rate used during finetune stage |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
801 """ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
802 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
803 self.sigmoid_layers = [] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
804 self.daa_layers = [] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
805 self.pretrain_functions = [] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
806 self.params = [] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
807 self.n_layers = len(hidden_layers_sizes) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
808 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
809 if len(hidden_layers_sizes) < 1 : |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
810 raiseException (' You must have at least one hidden layer ') |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
811 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
812 theano_rng = RandomStreams(rng.randint(2**30)) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
813 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
814 # allocate symbolic variables for the data |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
815 index = T.lscalar() # index to a [mini]batch |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
816 self.x = T.matrix('x') # the data is presented as rasterized images |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
817 self.y = T.ivector('y') # the labels are presented as 1D vector of |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
818 # [int] labels |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
819 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
820 # The SdA is an MLP, for which all weights of intermidiate layers |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
821 # are shared with a different denoising autoencoders |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
822 # We will first construct the SdA as a deep multilayer perceptron, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
823 # and when constructing each sigmoidal layer we also construct a |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
824 # denoising autoencoder that shares weights with that layer, and |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
825 # compile a training function for that denoising autoencoder |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
826 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
827 for i in xrange( self.n_layers ): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
828 # construct the sigmoidal layer |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
829 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
830 sigmoid_layer = SigmoidalLayer(rng, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
831 self.layers[-1].output if i else self.x, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
832 hidden_layers_sizes[i-1] if i else n_ins, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
833 hidden_layers_sizes[i]) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
834 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
835 daa_layer = DAA(corruption_level = corruption_levels[i], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
836 input = sigmoid_layer.input, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
837 W = sigmoid_layer.W, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
838 b = sigmoid_layer.b) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
839 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
840 # add the layer to the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
841 self.sigmoid_layers.append(sigmoid_layer) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
842 self.daa_layers.append(daa_layer) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
843 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
844 # its arguably a philosophical question... |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
845 # but we are going to only declare that the parameters of the |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
846 # sigmoid_layers are parameters of the StackedDAA |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
847 # the hidden-layer biases in the daa_layers are parameters of those |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
848 # daa_layers, but not the StackedDAA |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
849 self.params.extend(sigmoid_layer.params) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
850 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
851 # We now need to add a logistic layer on top of the MLP |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
852 self.logistic_regressor = LogisticRegression( |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
853 input = self.sigmoid_layers[-1].output, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
854 n_in = hidden_layers_sizes[-1], |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
855 n_out = n_outs) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
856 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
857 self.params.extend(self.logLayer.params) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
858 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
859 def pretraining_functions(self, train_set_x, batch_size): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
860 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
861 # compiles update functions for each layer, and |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
862 # returns them as a list |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
863 # |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
864 # Construct a function that trains this dA |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
865 # compute gradients of layer parameters |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
866 gparams = T.grad(dA_layer.cost, dA_layer.params) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
867 # compute the list of updates |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
868 updates = {} |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
869 for param, gparam in zip(dA_layer.params, gparams): |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
870 updates[param] = param - gparam * pretrain_lr |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
871 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
872 # create a function that trains the dA |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
873 update_fn = theano.function([index], dA_layer.cost, \ |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
874 updates = updates, |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
875 givens = { |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
876 self.x : train_set_x[index*batch_size:(index+1)*batch_size]}) |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
877 # collect this function into a list |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
878 self.pretrain_functions += [update_fn] |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
879 |
4bc5eeec6394
Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff
changeset
|
880 |