annotate code_tutoriel/deep.py @ 165:4bc5eeec6394

Updating the tutorial code to the latest revisions.
author Dumitru Erhan <dumitru.erhan@gmail.com>
date Fri, 26 Feb 2010 13:55:27 -0500
parents
children
rev   line source
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
1 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
2 Draft of DBN, DAA, SDAA, RBM tutorial code
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
3
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
4 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
5 import sys
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
6 import numpy
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
7 import theano
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
8 import time
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
9 import theano.tensor as T
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
10 from theano.tensor.shared_randomstreams import RandomStreams
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
11 from theano import shared, function
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
12
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
13 import gzip
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
14 import cPickle
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
15 import pylearn.io.image_tiling
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
16 import PIL
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
17
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
18 # NNET STUFF
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
19
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
20 class LogisticRegression(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
21 """Multi-class Logistic Regression Class
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
22
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
23 The logistic regression is fully described by a weight matrix :math:`W`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
24 and bias vector :math:`b`. Classification is done by projecting data
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
25 points onto a set of hyperplanes, the distance to which is used to
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
26 determine a class membership probability.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
27 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
28
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
29 def __init__(self, input, n_in, n_out):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
30 """ Initialize the parameters of the logistic regression
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
31 :param input: symbolic variable that describes the input of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
32 architecture (one minibatch)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
33 :type n_in: int
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
34 :param n_in: number of input units, the dimension of the space in
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
35 which the datapoints lie
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
36 :type n_out: int
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
37 :param n_out: number of output units, the dimension of the space in
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
38 which the labels lie
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
39 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
40
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
41 # initialize with 0 the weights W as a matrix of shape (n_in, n_out)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
42 self.W = theano.shared( value=numpy.zeros((n_in,n_out),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
43 dtype = theano.config.floatX) )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
44 # initialize the baises b as a vector of n_out 0s
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
45 self.b = theano.shared( value=numpy.zeros((n_out,),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
46 dtype = theano.config.floatX) )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
47 # compute vector of class-membership probabilities in symbolic form
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
48 self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W)+self.b)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
49
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
50 # compute prediction as class whose probability is maximal in
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
51 # symbolic form
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
52 self.y_pred=T.argmax(self.p_y_given_x, axis=1)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
53
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
54 # list of parameters for this layer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
55 self.params = [self.W, self.b]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
56
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
57 def negative_log_likelihood(self, y):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
58 """Return the mean of the negative log-likelihood of the prediction
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
59 of this model under a given target distribution.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
60 :param y: corresponds to a vector that gives for each example the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
61 correct label
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
62 Note: we use the mean instead of the sum so that
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
63 the learning rate is less dependent on the batch size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
64 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
65 return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]),y])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
66
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
67 def errors(self, y):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
68 """Return a float representing the number of errors in the minibatch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
69 over the total number of examples of the minibatch ; zero one
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
70 loss over the size of the minibatch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
71 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
72 # check if y has same dimension of y_pred
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
73 if y.ndim != self.y_pred.ndim:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
74 raise TypeError('y should have the same shape as self.y_pred',
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
75 ('y', target.type, 'y_pred', self.y_pred.type))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
76
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
77 # check if y is of the correct datatype
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
78 if y.dtype.startswith('int'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
79 # the T.neq operator returns a vector of 0s and 1s, where 1
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
80 # represents a mistake in prediction
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
81 return T.mean(T.neq(self.y_pred, y))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
82 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
83 raise NotImplementedError()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
84
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
85 class SigmoidalLayer(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
86 def __init__(self, rng, input, n_in, n_out):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
87 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
88 Typical hidden layer of a MLP: units are fully-connected and have
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
89 sigmoidal activation function. Weight matrix W is of shape (n_in,n_out)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
90 and the bias vector b is of shape (n_out,).
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
91
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
92 Hidden unit activation is given by: sigmoid(dot(input,W) + b)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
93
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
94 :type rng: numpy.random.RandomState
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
95 :param rng: a random number generator used to initialize weights
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
96 :type input: theano.tensor.matrix
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
97 :param input: a symbolic tensor of shape (n_examples, n_in)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
98 :type n_in: int
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
99 :param n_in: dimensionality of input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
100 :type n_out: int
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
101 :param n_out: number of hidden units
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
102 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
103 self.input = input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
104
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
105 W_values = numpy.asarray( rng.uniform( \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
106 low = -numpy.sqrt(6./(n_in+n_out)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
107 high = numpy.sqrt(6./(n_in+n_out)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
108 size = (n_in, n_out)), dtype = theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
109 self.W = theano.shared(value = W_values)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
110
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
111 b_values = numpy.zeros((n_out,), dtype= theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
112 self.b = theano.shared(value= b_values)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
113
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
114 self.output = T.nnet.sigmoid(T.dot(input, self.W) + self.b)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
115 self.params = [self.W, self.b]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
116
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
117 # PRETRAINING LAYERS
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
118
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
119 class RBM(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
120 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
121 *** WRITE THE ENERGY FUNCTION USE SAME LETTERS AS VARIABLE NAMES IN CODE
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
122 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
123
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
124 def __init__(self, input=None, n_visible=None, n_hidden=None,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
125 W=None, hbias=None, vbias=None,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
126 numpy_rng=None, theano_rng=None):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
127 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
128 RBM constructor. Defines the parameters of the model along with
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
129 basic operations for inferring hidden from visible (and vice-versa),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
130 as well as for performing CD updates.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
131
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
132 :param input: None for standalone RBMs or symbolic variable if RBM is
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
133 part of a larger graph.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
134
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
135 :param n_visible: number of visible units (necessary when W or vbias is None)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
136
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
137 :param n_hidden: number of hidden units (necessary when W or hbias is None)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
138
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
139 :param W: weights to use for the RBM. None means that a shared variable will be
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
140 created with a randomly chosen matrix of size (n_visible, n_hidden).
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
141
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
142 :param hbias: ***
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
143
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
144 :param vbias: ***
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
145
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
146 :param numpy_rng: random number generator (necessary when W is None)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
147
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
148 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
149
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
150 params = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
151 if W is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
152 # choose initial values for weight matrix of RBM
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
153 initial_W = numpy.asarray(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
154 numpy_rng.uniform( \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
155 low=-numpy.sqrt(6./(n_hidden+n_visible)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
156 high=numpy.sqrt(6./(n_hidden+n_visible)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
157 size=(n_visible, n_hidden)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
158 dtype=theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
159 W = theano.shared(value=initial_W, name='W')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
160 params.append(W)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
161
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
162 if hbias is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
163 # theano shared variables for hidden biases
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
164 hbias = theano.shared(value=numpy.zeros(n_hidden,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
165 dtype=theano.config.floatX), name='hbias')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
166 params.append(hbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
167
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
168 if vbias is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
169 # theano shared variables for visible biases
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
170 vbias = theano.shared(value=numpy.zeros(n_visible,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
171 dtype=theano.config.floatX), name='vbias')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
172 params.append(vbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
173
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
174 if input is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
175 # initialize input layer for standalone RBM or layer0 of DBN
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
176 input = T.matrix('input')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
177
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
178 # setup theano random number generator
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
179 if theano_rng is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
180 theano_rng = RandomStreams(numpy_rng.randint(2**30))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
181
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
182 self.visible = self.input = input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
183 self.W = W
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
184 self.hbias = hbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
185 self.vbias = vbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
186 self.theano_rng = theano_rng
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
187 self.params = params
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
188 self.hidden_mean = T.nnet.sigmoid(T.dot(input, W)+hbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
189 self.hidden_sample = theano_rng.binomial(self.hidden_mean.shape, 1, self.hidden_mean)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
190
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
191 def gibbs_k(self, v_sample, k):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
192 ''' This function implements k steps of Gibbs sampling '''
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
193
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
194 # We compute the visible after k steps of Gibbs by iterating
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
195 # over ``gibs_1`` for k times; this can be done in Theano using
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
196 # the `scan op`. For a more comprehensive description of scan see
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
197 # http://deeplearning.net/software/theano/library/scan.html .
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
198
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
199 def gibbs_1(v0_sample, W, hbias, vbias):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
200 ''' This function implements one Gibbs step '''
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
201
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
202 # compute the activation of the hidden units given a sample of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
203 # vissibles
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
204 h0_mean = T.nnet.sigmoid(T.dot(v0_sample, W) + hbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
205 # get a sample of the hiddens given their activation
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
206 h0_sample = self.theano_rng.binomial(h0_mean.shape, 1, h0_mean)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
207 # compute the activation of the visible given the hidden sample
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
208 v1_mean = T.nnet.sigmoid(T.dot(h0_sample, W.T) + vbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
209 # get a sample of the visible given their activation
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
210 v1_act = self.theano_rng.binomial(v1_mean.shape, 1, v1_mean)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
211 return [v1_mean, v1_act]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
212
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
213
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
214 # DEBUGGING TO DO ALL WITHOUT SCAN
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
215 if k == 1:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
216 return gibbs_1(v_sample, self.W, self.hbias, self.vbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
217
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
218
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
219 # Because we require as output two values, namely the mean field
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
220 # approximation of the visible and the sample obtained after k steps,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
221 # scan needs to know the shape of those two outputs. Scan takes
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
222 # this information from the variables containing the initial state
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
223 # of the outputs. Since we do not need a initial state of ``v_mean``
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
224 # we provide a dummy one used only to get the correct shape
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
225 v_mean = T.zeros_like(v_sample)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
226
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
227 # ``outputs_taps`` is an argument of scan which describes at each
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
228 # time step what past values of the outputs the function applied
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
229 # recursively needs. This is given in the form of a dictionary,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
230 # where the keys are outputs indexes, and values are a list of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
231 # of the offsets used by the corresponding outputs
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
232 # In our case the function ``gibbs_1`` applied recursively, requires
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
233 # at time k the past value k-1 for the first output (index 0) and
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
234 # no past value of the second output
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
235 outputs_taps = { 0 : [-1], 1 : [] }
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
236
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
237 v_means, v_samples = theano.scan( fn = gibbs_1,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
238 sequences = [],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
239 initial_states = [v_sample, v_mean],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
240 non_sequences = [self.W, self.hbias, self.vbias],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
241 outputs_taps = outputs_taps,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
242 n_steps = k)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
243 return v_means[-1], v_samples[-1]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
244
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
245 def free_energy(self, v_sample):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
246 wx_b = T.dot(v_sample, self.W) + self.hbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
247 vbias_term = T.sum(T.dot(v_sample, self.vbias))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
248 hidden_term = T.sum(T.log(1+T.exp(wx_b)))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
249 return -hidden_term - vbias_term
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
250
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
251 def cd(self, visible = None, persistent = None, steps = 1):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
252 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
253 Return a 5-tuple of values related to contrastive divergence: (cost,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
254 end-state of negative-phase chain, gradient on weights, gradient on
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
255 hidden bias, gradient on visible bias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
256
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
257 If visible is None, it defaults to self.input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
258 If persistent is None, it defaults to self.input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
259
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
260 CD aka CD1 - cd()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
261 CD-10 - cd(steps=10)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
262 PCD - cd(persistent=shared(numpy.asarray(initializer)))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
263 PCD-k - cd(persistent=shared(numpy.asarray(initializer)),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
264 steps=10)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
265 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
266 if visible is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
267 visible = self.input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
268
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
269 if visible is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
270 raise TypeError('visible argument is required when self.input is None')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
271
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
272 if steps is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
273 steps = self.gibbs_1
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
274
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
275 if persistent is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
276 chain_start = visible
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
277 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
278 chain_start = persistent
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
279
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
280 chain_end_mean, chain_end_sample = self.gibbs_k(chain_start, steps)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
281
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
282 #print >> sys.stderr, "WARNING: DEBUGGING with wrong FREE ENERGY"
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
283 #free_energy_delta = - self.free_energy(chain_end_sample)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
284 free_energy_delta = self.free_energy(visible) - self.free_energy(chain_end_sample)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
285
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
286 # we will return all of these regardless of what is in self.params
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
287 all_params = [self.W, self.hbias, self.vbias]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
288
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
289 gparams = T.grad(free_energy_delta, all_params,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
290 consider_constant = [chain_end_sample])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
291
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
292 cross_entropy = T.mean(T.sum(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
293 visible*T.log(chain_end_mean) + (1 - visible)*T.log(1-chain_end_mean),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
294 axis = 1))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
295
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
296 return (cross_entropy, chain_end_sample,) + tuple(gparams)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
297
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
298 def cd_updates(self, lr, visible = None, persistent = None, steps = 1):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
299 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
300 Return the learning updates for the RBM parameters that are shared variables.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
301
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
302 Also returns an update for the persistent if it is a shared variable.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
303
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
304 These updates are returned as a dictionary.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
305
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
306 :param lr: [scalar] learning rate for contrastive divergence learning
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
307 :param visible: see `cd_grad`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
308 :param persistent: see `cd_grad`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
309 :param steps: see `cd_grad`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
310
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
311 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
312
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
313 cross_entropy, chain_end, gW, ghbias, gvbias = self.cd(visible,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
314 persistent, steps)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
315
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
316 updates = {}
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
317 if hasattr(self.W, 'value'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
318 updates[self.W] = self.W - lr * gW
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
319 if hasattr(self.hbias, 'value'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
320 updates[self.hbias] = self.hbias - lr * ghbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
321 if hasattr(self.vbias, 'value'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
322 updates[self.vbias] = self.vbias - lr * gvbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
323 if persistent:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
324 #if persistent is a shared var, then it means we should use
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
325 updates[persistent] = chain_end
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
326
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
327 return updates
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
328
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
329 # DEEP MODELS
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
330
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
331 class DBN(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
332 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
333 *** WHAT IS A DBN?
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
334 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
335
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
336 def __init__(self, input_len, hidden_layers_sizes, n_classes, rng):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
337 """ This class is made to support a variable number of layers.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
338
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
339 :param train_set_x: symbolic variable pointing to the training dataset
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
340
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
341 :param train_set_y: symbolic variable pointing to the labels of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
342 training dataset
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
343
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
344 :param input_len: dimension of the input to the sdA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
345
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
346 :param n_layers_sizes: intermidiate layers size, must contain
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
347 at least one value
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
348
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
349 :param n_classes: dimension of the output of the network
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
350
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
351 :param corruption_levels: amount of corruption to use for each
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
352 layer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
353
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
354 :param rng: numpy random number generator used to draw initial weights
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
355
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
356 :param pretrain_lr: learning rate used during pre-trainnig stage
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
357
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
358 :param finetune_lr: learning rate used during finetune stage
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
359 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
360
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
361 self.sigmoid_layers = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
362 self.rbm_layers = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
363 self.pretrain_functions = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
364 self.params = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
365
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
366 theano_rng = RandomStreams(rng.randint(2**30))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
367
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
368 # allocate symbolic variables for the data
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
369 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
370 self.x = T.matrix('x') # the data is presented as rasterized images
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
371 self.y = T.ivector('y') # the labels are presented as 1D vector of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
372 # [int] labels
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
373 input = self.x
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
374
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
375 # The SdA is an MLP, for which all weights of intermidiate layers
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
376 # are shared with a different denoising autoencoders
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
377 # We will first construct the SdA as a deep multilayer perceptron,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
378 # and when constructing each sigmoidal layer we also construct a
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
379 # denoising autoencoder that shares weights with that layer, and
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
380 # compile a training function for that denoising autoencoder
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
381
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
382 for n_hid in hidden_layers_sizes:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
383 # construct the sigmoidal layer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
384
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
385 sigmoid_layer = SigmoidalLayer(rng, input, input_len, n_hid)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
386 self.sigmoid_layers.append(sigmoid_layer)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
387
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
388 self.rbm_layers.append(RBM(input=input,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
389 W=sigmoid_layer.W,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
390 hbias=sigmoid_layer.b,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
391 n_visible = input_len,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
392 n_hidden = n_hid,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
393 numpy_rng=rng,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
394 theano_rng=theano_rng))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
395
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
396 # its arguably a philosophical question...
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
397 # but we are going to only declare that the parameters of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
398 # sigmoid_layers are parameters of the StackedDAA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
399 # the hidden-layer biases in the daa_layers are parameters of those
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
400 # daa_layers, but not the StackedDAA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
401 self.params.extend(self.sigmoid_layers[-1].params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
402
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
403 # get ready for the next loop iteration
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
404 input_len = n_hid
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
405 input = self.sigmoid_layers[-1].output
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
406
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
407 # We now need to add a logistic layer on top of the MLP
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
408 self.logistic_regressor = LogisticRegression(input = input,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
409 n_in = input_len, n_out = n_classes)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
410
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
411 self.params.extend(self.logistic_regressor.params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
412
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
413 def pretraining_functions(self, train_set_x, batch_size, learning_rate, k=1):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
414 if k!=1:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
415 raise NotImplementedError()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
416 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
417 n_train_batches = train_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
418 batch_begin = (index % n_train_batches) * batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
419 batch_end = batch_begin+batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
420
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
421 print 'TRAIN_SET X', train_set_x.value.shape
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
422 rval = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
423 for rbm in self.rbm_layers:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
424 # N.B. these cd() samples are independent from the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
425 # samples used for learning
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
426 outputs = list(rbm.cd())[0:2]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
427 rval.append(function([index], outputs,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
428 updates = rbm.cd_updates(lr=learning_rate),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
429 givens = {self.x: train_set_x[batch_begin:batch_end]}))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
430 if rbm is self.rbm_layers[0]:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
431 f = rval[-1]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
432 AA=len(outputs)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
433 for i, implicit_out in enumerate(f.maker.env.outputs): #[len(outputs):]:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
434 print 'OUTPUT ', i
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
435 theano.printing.debugprint(implicit_out, file=sys.stdout)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
436
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
437 return rval
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
438
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
439 def finetune(self, datasets, lr, batch_size):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
440
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
441 # unpack the various datasets
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
442 (train_set_x, train_set_y) = datasets[0]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
443 (valid_set_x, valid_set_y) = datasets[1]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
444 (test_set_x, test_set_y) = datasets[2]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
445
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
446 # compute number of minibatches for training, validation and testing
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
447 assert train_set_x.value.shape[0] % batch_size == 0
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
448 assert valid_set_x.value.shape[0] % batch_size == 0
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
449 assert test_set_x.value.shape[0] % batch_size == 0
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
450 n_train_batches = train_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
451 n_valid_batches = valid_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
452 n_test_batches = test_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
453
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
454 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
455 target = self.y
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
456
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
457 train_index = index % n_train_batches
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
458
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
459 classifier = self.logistic_regressor
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
460 cost = classifier.negative_log_likelihood(target)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
461 # compute the gradients with respect to the model parameters
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
462 gparams = T.grad(cost, self.params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
463
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
464 # compute list of fine-tuning updates
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
465 updates = [(param, param - gparam*finetune_lr)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
466 for param,gparam in zip(self.params, gparams)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
467
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
468 train_fn = theano.function([index], cost,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
469 updates = updates,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
470 givens = {
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
471 self.x : train_set_x[train_index*batch_size:(train_index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
472 target : train_set_y[train_index*batch_size:(train_index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
473
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
474 test_score_i = theano.function([index], classifier.errors(target),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
475 givens = {
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
476 self.x: test_set_x[index*batch_size:(index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
477 target: test_set_y[index*batch_size:(index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
478
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
479 valid_score_i = theano.function([index], classifier.errors(target),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
480 givens = {
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
481 self.x: valid_set_x[index*batch_size:(index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
482 target: valid_set_y[index*batch_size:(index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
483
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
484 def test_scores():
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
485 return [test_score_i(i) for i in xrange(n_test_batches)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
486
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
487 def valid_scores():
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
488 return [valid_score_i(i) for i in xrange(n_valid_batches)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
489
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
490 return train_fn, valid_scores, test_scores
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
491
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
492 def load_mnist(filename):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
493 f = gzip.open(filename,'rb')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
494 train_set, valid_set, test_set = cPickle.load(f)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
495 f.close()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
496
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
497 def shared_dataset(data_xy):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
498 data_x, data_y = data_xy
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
499 shared_x = theano.shared(numpy.asarray(data_x, dtype=theano.config.floatX))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
500 shared_y = theano.shared(numpy.asarray(data_y, dtype=theano.config.floatX))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
501 return shared_x, T.cast(shared_y, 'int32')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
502
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
503 n_train_examples = train_set[0].shape[0]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
504 datasets = shared_dataset(train_set), shared_dataset(valid_set), shared_dataset(test_set)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
505
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
506 return n_train_examples, datasets
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
507
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
508 def dbn_main(finetune_lr = 0.01,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
509 pretraining_epochs = 10,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
510 pretrain_lr = 0.1,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
511 training_epochs = 1000,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
512 batch_size = 20,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
513 mnist_file='mnist.pkl.gz'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
514 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
515 Demonstrate stochastic gradient descent optimization for a multilayer perceptron
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
516
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
517 This is demonstrated on MNIST.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
518
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
519 :param learning_rate: learning rate used in the finetune stage
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
520 (factor for the stochastic gradient)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
521
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
522 :param pretraining_epochs: number of epoch to do pretraining
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
523
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
524 :param pretrain_lr: learning rate to be used during pre-training
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
525
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
526 :param n_iter: maximal number of iterations ot run the optimizer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
527
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
528 :param mnist_file: path the the pickled mnist_file
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
529
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
530 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
531
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
532 n_train_examples, train_valid_test = load_mnist(mnist_file)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
533
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
534 print "Creating a Deep Belief Network"
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
535 deep_model = DBN(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
536 input_len=28*28,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
537 hidden_layers_sizes = [500, 150, 100],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
538 n_classes=10,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
539 rng = numpy.random.RandomState())
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
540
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
541 ####
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
542 #### Phase 1: Pre-training
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
543 ####
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
544 print "Pretraining (unsupervised learning) ..."
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
545
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
546 pretrain_functions = deep_model.pretraining_functions(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
547 batch_size=batch_size,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
548 train_set_x=train_valid_test[0][0],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
549 learning_rate=pretrain_lr,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
550 )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
551
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
552 start_time = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
553 for layer_idx, pretrain_fn in enumerate(pretrain_functions):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
554 # go through pretraining epochs
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
555 print 'Pre-training layer %i'% layer_idx
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
556 for i in xrange(pretraining_epochs * n_train_examples / batch_size):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
557 outstuff = pretrain_fn(i)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
558 xe, negsample = outstuff[:2]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
559 print (layer_idx, i,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
560 n_train_examples / batch_size,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
561 float(xe),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
562 'Wmin', deep_model.rbm_layers[0].W.value.min(),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
563 'Wmax', deep_model.rbm_layers[0].W.value.max(),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
564 'vmin', deep_model.rbm_layers[0].vbias.value.min(),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
565 'vmax', deep_model.rbm_layers[0].vbias.value.max(),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
566 #'x>0.3', (input_i>0.3).sum(),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
567 )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
568 sys.stdout.flush()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
569 if i % 1000 == 0:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
570 PIL.Image.fromarray(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
571 pylearn.io.image_tiling.tile_raster_images(negsample, (28,28), (10,10),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
572 tile_spacing=(1,1))).save('samples_%i_%i.png'%(layer_idx,i))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
573
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
574 PIL.Image.fromarray(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
575 pylearn.io.image_tiling.tile_raster_images(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
576 deep_model.rbm_layers[0].W.value.T,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
577 (28,28), (10,10),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
578 tile_spacing=(1,1))).save('filters_%i_%i.png'%(layer_idx,i))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
579 end_time = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
580 print 'Pretraining took %f minutes' %((end_time - start_time)/60.)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
581
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
582 return
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
583
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
584 print "Fine tuning (supervised learning) ..."
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
585 train_fn, valid_scores, test_scores =\
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
586 deep_model.finetune_functions(train_valid_test[0][0],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
587 learning_rate=finetune_lr, # the learning rate
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
588 batch_size = batch_size) # number of examples to use at once
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
589
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
590 ####
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
591 #### Phase 2: Fine Tuning
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
592 ####
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
593
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
594 patience = 10000 # look as this many examples regardless
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
595 patience_increase = 2. # wait this much longer when a new best is
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
596 # found
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
597 improvement_threshold = 0.995 # a relative improvement of this much is
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
598 # considered significant
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
599 validation_frequency = min(n_train_examples, patience/2)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
600 # go through this many
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
601 # minibatche before checking the network
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
602 # on the validation set; in this case we
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
603 # check every epoch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
604
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
605 patience_max = n_train_examples * training_epochs
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
606
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
607 best_epoch = None
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
608 best_epoch_test_score = None
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
609 best_epoch_valid_score = float('inf')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
610 start_time = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
611
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
612 for i in xrange(patience_max):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
613 if i >= patience:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
614 break
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
615
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
616 cost_i = train_fn(i)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
617
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
618 if i % validation_frequency == 0:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
619 validation_i = numpy.mean([score for score in valid_scores()])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
620
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
621 # if we got the best validation score until now
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
622 if validation_i < best_epoch_valid_score:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
623
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
624 # improve patience if loss improvement is good enough
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
625 threshold_i = best_epoch_valid_score * improvement_threshold
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
626 if validation_i < threshold_i:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
627 patience = max(patience, i * patience_increase)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
628
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
629 # save best validation score and iteration number
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
630 best_epoch_valid_score = validation_i
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
631 best_epoch = i/validation_i
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
632 best_epoch_test_score = numpy.mean(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
633 [score for score in test_scores()])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
634
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
635 print('epoch %i, validation error %f %%, test error %f %%'%(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
636 i/validation_frequency, validation_i*100.,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
637 best_epoch_test_score*100.))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
638 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
639 print('epoch %i, validation error %f %%' % (
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
640 i/validation_frequency, validation_i*100.))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
641 end_time = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
642
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
643 print(('Optimization complete with best validation score of %f %%,'
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
644 'with test performance %f %%') %
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
645 (finetune_status['best_validation_loss']*100.,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
646 finetune_status['test_score']*100.))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
647 print ('The code ran for %f minutes' % ((finetune_status['duration'])/60.))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
648
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
649 def rbm_main():
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
650 rbm = RBM(n_visible=20, n_hidden=30,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
651 numpy_rng = numpy.random.RandomState(34))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
652
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
653 cd_updates = rbm.cd_updates(lr=0.25)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
654
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
655 print cd_updates
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
656
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
657 f = function([rbm.input], [],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
658 updates={rbm.W:cd_updates[rbm.W]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
659
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
660 theano.printing.debugprint(f.maker.env.outputs[0],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
661 file=sys.stdout)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
662
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
663
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
664 if __name__ == '__main__':
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
665 dbn_main()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
666 #rbm_main()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
667
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
668
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
669 if 0:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
670 class DAA(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
671 def __init__(self, n_visible= 784, n_hidden= 500, corruption_level = 0.1,\
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
672 input = None, shared_W = None, shared_b = None):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
673 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
674 Initialize the dA class by specifying the number of visible units (the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
675 dimension d of the input ), the number of hidden units ( the dimension
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
676 d' of the latent or hidden space ) and the corruption level. The
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
677 constructor also receives symbolic variables for the input, weights and
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
678 bias. Such a symbolic variables are useful when, for example the input is
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
679 the result of some computations, or when weights are shared between the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
680 dA and an MLP layer. When dealing with SdAs this always happens,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
681 the dA on layer 2 gets as input the output of the dA on layer 1,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
682 and the weights of the dA are used in the second stage of training
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
683 to construct an MLP.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
684
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
685 :param n_visible: number of visible units
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
686
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
687 :param n_hidden: number of hidden units
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
688
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
689 :param input: a symbolic description of the input or None
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
690
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
691 :param corruption_level: the corruption mechanism picks up randomly this
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
692 fraction of entries of the input and turns them to 0
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
693
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
694
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
695 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
696 self.n_visible = n_visible
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
697 self.n_hidden = n_hidden
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
698
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
699 # create a Theano random generator that gives symbolic random values
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
700 theano_rng = RandomStreams()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
701
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
702 if shared_W != None and shared_b != None :
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
703 self.W = shared_W
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
704 self.b = shared_b
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
705 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
706 # initial values for weights and biases
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
707 # note : W' was written as `W_prime` and b' as `b_prime`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
708
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
709 # W is initialized with `initial_W` which is uniformely sampled
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
710 # from -6./sqrt(n_visible+n_hidden) and 6./sqrt(n_hidden+n_visible)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
711 # the output of uniform if converted using asarray to dtype
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
712 # theano.config.floatX so that the code is runable on GPU
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
713 initial_W = numpy.asarray( numpy.random.uniform( \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
714 low = -numpy.sqrt(6./(n_hidden+n_visible)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
715 high = numpy.sqrt(6./(n_hidden+n_visible)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
716 size = (n_visible, n_hidden)), dtype = theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
717 initial_b = numpy.zeros(n_hidden, dtype = theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
718
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
719
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
720 # theano shared variables for weights and biases
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
721 self.W = theano.shared(value = initial_W, name = "W")
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
722 self.b = theano.shared(value = initial_b, name = "b")
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
723
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
724
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
725 initial_b_prime= numpy.zeros(n_visible)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
726 # tied weights, therefore W_prime is W transpose
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
727 self.W_prime = self.W.T
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
728 self.b_prime = theano.shared(value = initial_b_prime, name = "b'")
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
729
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
730 # if no input is given, generate a variable representing the input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
731 if input == None :
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
732 # we use a matrix because we expect a minibatch of several examples,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
733 # each example being a row
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
734 self.x = T.matrix(name = 'input')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
735 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
736 self.x = input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
737 # Equation (1)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
738 # keep 90% of the inputs the same and zero-out randomly selected subset of 10% of the inputs
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
739 # note : first argument of theano.rng.binomial is the shape(size) of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
740 # random numbers that it should produce
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
741 # second argument is the number of trials
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
742 # third argument is the probability of success of any trial
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
743 #
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
744 # this will produce an array of 0s and 1s where 1 has a
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
745 # probability of 1 - ``corruption_level`` and 0 with
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
746 # ``corruption_level``
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
747 self.tilde_x = theano_rng.binomial( self.x.shape, 1, 1 - corruption_level) * self.x
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
748 # Equation (2)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
749 # note : y is stored as an attribute of the class so that it can be
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
750 # used later when stacking dAs.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
751 self.y = T.nnet.sigmoid(T.dot(self.tilde_x, self.W ) + self.b)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
752 # Equation (3)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
753 self.z = T.nnet.sigmoid(T.dot(self.y, self.W_prime) + self.b_prime)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
754 # Equation (4)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
755 # note : we sum over the size of a datapoint; if we are using minibatches,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
756 # L will be a vector, with one entry per example in minibatch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
757 self.L = - T.sum( self.x*T.log(self.z) + (1-self.x)*T.log(1-self.z), axis=1 )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
758 # note : L is now a vector, where each element is the cross-entropy cost
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
759 # of the reconstruction of the corresponding example of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
760 # minibatch. We need to compute the average of all these to get
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
761 # the cost of the minibatch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
762 self.cost = T.mean(self.L)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
763
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
764 self.params = [ self.W, self.b, self.b_prime ]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
765
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
766 class StackedDAA(DeepLayerwiseModel):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
767 """Stacked denoising auto-encoder class (SdA)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
768
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
769 A stacked denoising autoencoder model is obtained by stacking several
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
770 dAs. The hidden layer of the dA at layer `i` becomes the input of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
771 the dA at layer `i+1`. The first layer dA gets as input the input of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
772 the SdA, and the hidden layer of the last dA represents the output.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
773 Note that after pretraining, the SdA is dealt with as a normal MLP,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
774 the dAs are only used to initialize the weights.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
775 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
776
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
777 def __init__(self, n_ins, hidden_layers_sizes, n_outs,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
778 corruption_levels, rng, ):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
779 """ This class is made to support a variable number of layers.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
780
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
781 :param train_set_x: symbolic variable pointing to the training dataset
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
782
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
783 :param train_set_y: symbolic variable pointing to the labels of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
784 training dataset
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
785
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
786 :param n_ins: dimension of the input to the sdA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
787
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
788 :param n_layers_sizes: intermidiate layers size, must contain
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
789 at least one value
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
790
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
791 :param n_outs: dimension of the output of the network
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
792
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
793 :param corruption_levels: amount of corruption to use for each
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
794 layer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
795
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
796 :param rng: numpy random number generator used to draw initial weights
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
797
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
798 :param pretrain_lr: learning rate used during pre-trainnig stage
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
799
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
800 :param finetune_lr: learning rate used during finetune stage
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
801 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
802
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
803 self.sigmoid_layers = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
804 self.daa_layers = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
805 self.pretrain_functions = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
806 self.params = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
807 self.n_layers = len(hidden_layers_sizes)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
808
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
809 if len(hidden_layers_sizes) < 1 :
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
810 raiseException (' You must have at least one hidden layer ')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
811
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
812 theano_rng = RandomStreams(rng.randint(2**30))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
813
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
814 # allocate symbolic variables for the data
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
815 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
816 self.x = T.matrix('x') # the data is presented as rasterized images
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
817 self.y = T.ivector('y') # the labels are presented as 1D vector of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
818 # [int] labels
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
819
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
820 # The SdA is an MLP, for which all weights of intermidiate layers
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
821 # are shared with a different denoising autoencoders
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
822 # We will first construct the SdA as a deep multilayer perceptron,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
823 # and when constructing each sigmoidal layer we also construct a
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
824 # denoising autoencoder that shares weights with that layer, and
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
825 # compile a training function for that denoising autoencoder
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
826
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
827 for i in xrange( self.n_layers ):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
828 # construct the sigmoidal layer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
829
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
830 sigmoid_layer = SigmoidalLayer(rng,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
831 self.layers[-1].output if i else self.x,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
832 hidden_layers_sizes[i-1] if i else n_ins,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
833 hidden_layers_sizes[i])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
834
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
835 daa_layer = DAA(corruption_level = corruption_levels[i],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
836 input = sigmoid_layer.input,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
837 W = sigmoid_layer.W,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
838 b = sigmoid_layer.b)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
839
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
840 # add the layer to the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
841 self.sigmoid_layers.append(sigmoid_layer)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
842 self.daa_layers.append(daa_layer)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
843
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
844 # its arguably a philosophical question...
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
845 # but we are going to only declare that the parameters of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
846 # sigmoid_layers are parameters of the StackedDAA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
847 # the hidden-layer biases in the daa_layers are parameters of those
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
848 # daa_layers, but not the StackedDAA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
849 self.params.extend(sigmoid_layer.params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
850
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
851 # We now need to add a logistic layer on top of the MLP
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
852 self.logistic_regressor = LogisticRegression(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
853 input = self.sigmoid_layers[-1].output,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
854 n_in = hidden_layers_sizes[-1],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
855 n_out = n_outs)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
856
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
857 self.params.extend(self.logLayer.params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
858
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
859 def pretraining_functions(self, train_set_x, batch_size):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
860
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
861 # compiles update functions for each layer, and
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
862 # returns them as a list
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
863 #
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
864 # Construct a function that trains this dA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
865 # compute gradients of layer parameters
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
866 gparams = T.grad(dA_layer.cost, dA_layer.params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
867 # compute the list of updates
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
868 updates = {}
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
869 for param, gparam in zip(dA_layer.params, gparams):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
870 updates[param] = param - gparam * pretrain_lr
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
871
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
872 # create a function that trains the dA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
873 update_fn = theano.function([index], dA_layer.cost, \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
874 updates = updates,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
875 givens = {
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
876 self.x : train_set_x[index*batch_size:(index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
877 # collect this function into a list
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
878 self.pretrain_functions += [update_fn]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
879
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
880