annotate code_tutoriel/deep.py @ 239:42005ec87747

Mergé (manuellement) les changements de Sylvain pour utiliser le code de dataset d'Arnaud, à cette différence près que je n'utilse pas les givens. J'ai probablement une approche différente pour limiter la taille du dataset dans mon débuggage, aussi.
author fsavard
date Mon, 15 Mar 2010 18:30:21 -0400
parents 4bc5eeec6394
children
rev   line source
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
1 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
2 Draft of DBN, DAA, SDAA, RBM tutorial code
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
3
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
4 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
5 import sys
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
6 import numpy
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
7 import theano
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
8 import time
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
9 import theano.tensor as T
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
10 from theano.tensor.shared_randomstreams import RandomStreams
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
11 from theano import shared, function
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
12
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
13 import gzip
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
14 import cPickle
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
15 import pylearn.io.image_tiling
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
16 import PIL
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
17
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
18 # NNET STUFF
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
19
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
20 class LogisticRegression(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
21 """Multi-class Logistic Regression Class
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
22
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
23 The logistic regression is fully described by a weight matrix :math:`W`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
24 and bias vector :math:`b`. Classification is done by projecting data
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
25 points onto a set of hyperplanes, the distance to which is used to
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
26 determine a class membership probability.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
27 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
28
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
29 def __init__(self, input, n_in, n_out):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
30 """ Initialize the parameters of the logistic regression
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
31 :param input: symbolic variable that describes the input of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
32 architecture (one minibatch)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
33 :type n_in: int
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
34 :param n_in: number of input units, the dimension of the space in
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
35 which the datapoints lie
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
36 :type n_out: int
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
37 :param n_out: number of output units, the dimension of the space in
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
38 which the labels lie
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
39 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
40
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
41 # initialize with 0 the weights W as a matrix of shape (n_in, n_out)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
42 self.W = theano.shared( value=numpy.zeros((n_in,n_out),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
43 dtype = theano.config.floatX) )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
44 # initialize the baises b as a vector of n_out 0s
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
45 self.b = theano.shared( value=numpy.zeros((n_out,),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
46 dtype = theano.config.floatX) )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
47 # compute vector of class-membership probabilities in symbolic form
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
48 self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W)+self.b)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
49
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
50 # compute prediction as class whose probability is maximal in
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
51 # symbolic form
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
52 self.y_pred=T.argmax(self.p_y_given_x, axis=1)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
53
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
54 # list of parameters for this layer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
55 self.params = [self.W, self.b]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
56
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
57 def negative_log_likelihood(self, y):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
58 """Return the mean of the negative log-likelihood of the prediction
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
59 of this model under a given target distribution.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
60 :param y: corresponds to a vector that gives for each example the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
61 correct label
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
62 Note: we use the mean instead of the sum so that
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
63 the learning rate is less dependent on the batch size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
64 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
65 return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]),y])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
66
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
67 def errors(self, y):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
68 """Return a float representing the number of errors in the minibatch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
69 over the total number of examples of the minibatch ; zero one
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
70 loss over the size of the minibatch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
71 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
72 # check if y has same dimension of y_pred
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
73 if y.ndim != self.y_pred.ndim:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
74 raise TypeError('y should have the same shape as self.y_pred',
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
75 ('y', target.type, 'y_pred', self.y_pred.type))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
76
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
77 # check if y is of the correct datatype
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
78 if y.dtype.startswith('int'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
79 # the T.neq operator returns a vector of 0s and 1s, where 1
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
80 # represents a mistake in prediction
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
81 return T.mean(T.neq(self.y_pred, y))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
82 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
83 raise NotImplementedError()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
84
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
85 class SigmoidalLayer(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
86 def __init__(self, rng, input, n_in, n_out):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
87 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
88 Typical hidden layer of a MLP: units are fully-connected and have
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
89 sigmoidal activation function. Weight matrix W is of shape (n_in,n_out)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
90 and the bias vector b is of shape (n_out,).
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
91
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
92 Hidden unit activation is given by: sigmoid(dot(input,W) + b)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
93
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
94 :type rng: numpy.random.RandomState
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
95 :param rng: a random number generator used to initialize weights
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
96 :type input: theano.tensor.matrix
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
97 :param input: a symbolic tensor of shape (n_examples, n_in)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
98 :type n_in: int
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
99 :param n_in: dimensionality of input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
100 :type n_out: int
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
101 :param n_out: number of hidden units
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
102 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
103 self.input = input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
104
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
105 W_values = numpy.asarray( rng.uniform( \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
106 low = -numpy.sqrt(6./(n_in+n_out)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
107 high = numpy.sqrt(6./(n_in+n_out)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
108 size = (n_in, n_out)), dtype = theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
109 self.W = theano.shared(value = W_values)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
110
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
111 b_values = numpy.zeros((n_out,), dtype= theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
112 self.b = theano.shared(value= b_values)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
113
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
114 self.output = T.nnet.sigmoid(T.dot(input, self.W) + self.b)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
115 self.params = [self.W, self.b]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
116
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
117 # PRETRAINING LAYERS
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
118
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
119 class RBM(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
120 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
121 *** WRITE THE ENERGY FUNCTION USE SAME LETTERS AS VARIABLE NAMES IN CODE
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
122 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
123
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
124 def __init__(self, input=None, n_visible=None, n_hidden=None,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
125 W=None, hbias=None, vbias=None,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
126 numpy_rng=None, theano_rng=None):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
127 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
128 RBM constructor. Defines the parameters of the model along with
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
129 basic operations for inferring hidden from visible (and vice-versa),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
130 as well as for performing CD updates.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
131
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
132 :param input: None for standalone RBMs or symbolic variable if RBM is
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
133 part of a larger graph.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
134
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
135 :param n_visible: number of visible units (necessary when W or vbias is None)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
136
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
137 :param n_hidden: number of hidden units (necessary when W or hbias is None)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
138
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
139 :param W: weights to use for the RBM. None means that a shared variable will be
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
140 created with a randomly chosen matrix of size (n_visible, n_hidden).
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
141
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
142 :param hbias: ***
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
143
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
144 :param vbias: ***
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
145
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
146 :param numpy_rng: random number generator (necessary when W is None)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
147
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
148 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
149
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
150 params = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
151 if W is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
152 # choose initial values for weight matrix of RBM
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
153 initial_W = numpy.asarray(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
154 numpy_rng.uniform( \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
155 low=-numpy.sqrt(6./(n_hidden+n_visible)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
156 high=numpy.sqrt(6./(n_hidden+n_visible)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
157 size=(n_visible, n_hidden)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
158 dtype=theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
159 W = theano.shared(value=initial_W, name='W')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
160 params.append(W)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
161
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
162 if hbias is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
163 # theano shared variables for hidden biases
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
164 hbias = theano.shared(value=numpy.zeros(n_hidden,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
165 dtype=theano.config.floatX), name='hbias')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
166 params.append(hbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
167
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
168 if vbias is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
169 # theano shared variables for visible biases
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
170 vbias = theano.shared(value=numpy.zeros(n_visible,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
171 dtype=theano.config.floatX), name='vbias')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
172 params.append(vbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
173
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
174 if input is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
175 # initialize input layer for standalone RBM or layer0 of DBN
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
176 input = T.matrix('input')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
177
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
178 # setup theano random number generator
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
179 if theano_rng is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
180 theano_rng = RandomStreams(numpy_rng.randint(2**30))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
181
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
182 self.visible = self.input = input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
183 self.W = W
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
184 self.hbias = hbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
185 self.vbias = vbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
186 self.theano_rng = theano_rng
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
187 self.params = params
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
188 self.hidden_mean = T.nnet.sigmoid(T.dot(input, W)+hbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
189 self.hidden_sample = theano_rng.binomial(self.hidden_mean.shape, 1, self.hidden_mean)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
190
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
191 def gibbs_k(self, v_sample, k):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
192 ''' This function implements k steps of Gibbs sampling '''
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
193
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
194 # We compute the visible after k steps of Gibbs by iterating
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
195 # over ``gibs_1`` for k times; this can be done in Theano using
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
196 # the `scan op`. For a more comprehensive description of scan see
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
197 # http://deeplearning.net/software/theano/library/scan.html .
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
198
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
199 def gibbs_1(v0_sample, W, hbias, vbias):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
200 ''' This function implements one Gibbs step '''
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
201
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
202 # compute the activation of the hidden units given a sample of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
203 # vissibles
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
204 h0_mean = T.nnet.sigmoid(T.dot(v0_sample, W) + hbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
205 # get a sample of the hiddens given their activation
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
206 h0_sample = self.theano_rng.binomial(h0_mean.shape, 1, h0_mean)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
207 # compute the activation of the visible given the hidden sample
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
208 v1_mean = T.nnet.sigmoid(T.dot(h0_sample, W.T) + vbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
209 # get a sample of the visible given their activation
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
210 v1_act = self.theano_rng.binomial(v1_mean.shape, 1, v1_mean)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
211 return [v1_mean, v1_act]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
212
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
213
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
214 # DEBUGGING TO DO ALL WITHOUT SCAN
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
215 if k == 1:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
216 return gibbs_1(v_sample, self.W, self.hbias, self.vbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
217
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
218
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
219 # Because we require as output two values, namely the mean field
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
220 # approximation of the visible and the sample obtained after k steps,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
221 # scan needs to know the shape of those two outputs. Scan takes
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
222 # this information from the variables containing the initial state
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
223 # of the outputs. Since we do not need a initial state of ``v_mean``
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
224 # we provide a dummy one used only to get the correct shape
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
225 v_mean = T.zeros_like(v_sample)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
226
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
227 # ``outputs_taps`` is an argument of scan which describes at each
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
228 # time step what past values of the outputs the function applied
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
229 # recursively needs. This is given in the form of a dictionary,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
230 # where the keys are outputs indexes, and values are a list of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
231 # of the offsets used by the corresponding outputs
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
232 # In our case the function ``gibbs_1`` applied recursively, requires
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
233 # at time k the past value k-1 for the first output (index 0) and
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
234 # no past value of the second output
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
235 outputs_taps = { 0 : [-1], 1 : [] }
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
236
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
237 v_means, v_samples = theano.scan( fn = gibbs_1,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
238 sequences = [],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
239 initial_states = [v_sample, v_mean],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
240 non_sequences = [self.W, self.hbias, self.vbias],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
241 outputs_taps = outputs_taps,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
242 n_steps = k)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
243 return v_means[-1], v_samples[-1]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
244
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
245 def free_energy(self, v_sample):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
246 wx_b = T.dot(v_sample, self.W) + self.hbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
247 vbias_term = T.sum(T.dot(v_sample, self.vbias))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
248 hidden_term = T.sum(T.log(1+T.exp(wx_b)))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
249 return -hidden_term - vbias_term
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
250
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
251 def cd(self, visible = None, persistent = None, steps = 1):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
252 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
253 Return a 5-tuple of values related to contrastive divergence: (cost,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
254 end-state of negative-phase chain, gradient on weights, gradient on
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
255 hidden bias, gradient on visible bias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
256
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
257 If visible is None, it defaults to self.input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
258 If persistent is None, it defaults to self.input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
259
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
260 CD aka CD1 - cd()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
261 CD-10 - cd(steps=10)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
262 PCD - cd(persistent=shared(numpy.asarray(initializer)))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
263 PCD-k - cd(persistent=shared(numpy.asarray(initializer)),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
264 steps=10)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
265 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
266 if visible is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
267 visible = self.input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
268
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
269 if visible is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
270 raise TypeError('visible argument is required when self.input is None')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
271
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
272 if steps is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
273 steps = self.gibbs_1
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
274
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
275 if persistent is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
276 chain_start = visible
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
277 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
278 chain_start = persistent
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
279
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
280 chain_end_mean, chain_end_sample = self.gibbs_k(chain_start, steps)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
281
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
282 #print >> sys.stderr, "WARNING: DEBUGGING with wrong FREE ENERGY"
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
283 #free_energy_delta = - self.free_energy(chain_end_sample)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
284 free_energy_delta = self.free_energy(visible) - self.free_energy(chain_end_sample)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
285
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
286 # we will return all of these regardless of what is in self.params
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
287 all_params = [self.W, self.hbias, self.vbias]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
288
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
289 gparams = T.grad(free_energy_delta, all_params,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
290 consider_constant = [chain_end_sample])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
291
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
292 cross_entropy = T.mean(T.sum(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
293 visible*T.log(chain_end_mean) + (1 - visible)*T.log(1-chain_end_mean),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
294 axis = 1))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
295
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
296 return (cross_entropy, chain_end_sample,) + tuple(gparams)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
297
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
298 def cd_updates(self, lr, visible = None, persistent = None, steps = 1):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
299 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
300 Return the learning updates for the RBM parameters that are shared variables.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
301
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
302 Also returns an update for the persistent if it is a shared variable.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
303
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
304 These updates are returned as a dictionary.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
305
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
306 :param lr: [scalar] learning rate for contrastive divergence learning
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
307 :param visible: see `cd_grad`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
308 :param persistent: see `cd_grad`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
309 :param steps: see `cd_grad`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
310
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
311 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
312
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
313 cross_entropy, chain_end, gW, ghbias, gvbias = self.cd(visible,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
314 persistent, steps)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
315
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
316 updates = {}
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
317 if hasattr(self.W, 'value'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
318 updates[self.W] = self.W - lr * gW
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
319 if hasattr(self.hbias, 'value'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
320 updates[self.hbias] = self.hbias - lr * ghbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
321 if hasattr(self.vbias, 'value'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
322 updates[self.vbias] = self.vbias - lr * gvbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
323 if persistent:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
324 #if persistent is a shared var, then it means we should use
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
325 updates[persistent] = chain_end
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
326
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
327 return updates
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
328
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
329 # DEEP MODELS
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
330
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
331 class DBN(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
332 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
333 *** WHAT IS A DBN?
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
334 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
335
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
336 def __init__(self, input_len, hidden_layers_sizes, n_classes, rng):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
337 """ This class is made to support a variable number of layers.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
338
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
339 :param train_set_x: symbolic variable pointing to the training dataset
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
340
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
341 :param train_set_y: symbolic variable pointing to the labels of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
342 training dataset
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
343
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
344 :param input_len: dimension of the input to the sdA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
345
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
346 :param n_layers_sizes: intermidiate layers size, must contain
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
347 at least one value
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
348
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
349 :param n_classes: dimension of the output of the network
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
350
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
351 :param corruption_levels: amount of corruption to use for each
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
352 layer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
353
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
354 :param rng: numpy random number generator used to draw initial weights
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
355
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
356 :param pretrain_lr: learning rate used during pre-trainnig stage
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
357
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
358 :param finetune_lr: learning rate used during finetune stage
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
359 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
360
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
361 self.sigmoid_layers = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
362 self.rbm_layers = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
363 self.pretrain_functions = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
364 self.params = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
365
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
366 theano_rng = RandomStreams(rng.randint(2**30))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
367
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
368 # allocate symbolic variables for the data
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
369 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
370 self.x = T.matrix('x') # the data is presented as rasterized images
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
371 self.y = T.ivector('y') # the labels are presented as 1D vector of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
372 # [int] labels
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
373 input = self.x
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
374
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
375 # The SdA is an MLP, for which all weights of intermidiate layers
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
376 # are shared with a different denoising autoencoders
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
377 # We will first construct the SdA as a deep multilayer perceptron,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
378 # and when constructing each sigmoidal layer we also construct a
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
379 # denoising autoencoder that shares weights with that layer, and
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
380 # compile a training function for that denoising autoencoder
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
381
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
382 for n_hid in hidden_layers_sizes:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
383 # construct the sigmoidal layer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
384
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
385 sigmoid_layer = SigmoidalLayer(rng, input, input_len, n_hid)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
386 self.sigmoid_layers.append(sigmoid_layer)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
387
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
388 self.rbm_layers.append(RBM(input=input,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
389 W=sigmoid_layer.W,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
390 hbias=sigmoid_layer.b,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
391 n_visible = input_len,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
392 n_hidden = n_hid,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
393 numpy_rng=rng,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
394 theano_rng=theano_rng))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
395
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
396 # its arguably a philosophical question...
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
397 # but we are going to only declare that the parameters of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
398 # sigmoid_layers are parameters of the StackedDAA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
399 # the hidden-layer biases in the daa_layers are parameters of those
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
400 # daa_layers, but not the StackedDAA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
401 self.params.extend(self.sigmoid_layers[-1].params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
402
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
403 # get ready for the next loop iteration
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
404 input_len = n_hid
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
405 input = self.sigmoid_layers[-1].output
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
406
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
407 # We now need to add a logistic layer on top of the MLP
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
408 self.logistic_regressor = LogisticRegression(input = input,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
409 n_in = input_len, n_out = n_classes)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
410
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
411 self.params.extend(self.logistic_regressor.params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
412
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
413 def pretraining_functions(self, train_set_x, batch_size, learning_rate, k=1):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
414 if k!=1:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
415 raise NotImplementedError()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
416 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
417 n_train_batches = train_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
418 batch_begin = (index % n_train_batches) * batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
419 batch_end = batch_begin+batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
420
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
421 print 'TRAIN_SET X', train_set_x.value.shape
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
422 rval = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
423 for rbm in self.rbm_layers:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
424 # N.B. these cd() samples are independent from the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
425 # samples used for learning
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
426 outputs = list(rbm.cd())[0:2]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
427 rval.append(function([index], outputs,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
428 updates = rbm.cd_updates(lr=learning_rate),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
429 givens = {self.x: train_set_x[batch_begin:batch_end]}))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
430 if rbm is self.rbm_layers[0]:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
431 f = rval[-1]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
432 AA=len(outputs)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
433 for i, implicit_out in enumerate(f.maker.env.outputs): #[len(outputs):]:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
434 print 'OUTPUT ', i
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
435 theano.printing.debugprint(implicit_out, file=sys.stdout)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
436
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
437 return rval
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
438
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
439 def finetune(self, datasets, lr, batch_size):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
440
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
441 # unpack the various datasets
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
442 (train_set_x, train_set_y) = datasets[0]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
443 (valid_set_x, valid_set_y) = datasets[1]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
444 (test_set_x, test_set_y) = datasets[2]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
445
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
446 # compute number of minibatches for training, validation and testing
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
447 assert train_set_x.value.shape[0] % batch_size == 0
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
448 assert valid_set_x.value.shape[0] % batch_size == 0
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
449 assert test_set_x.value.shape[0] % batch_size == 0
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
450 n_train_batches = train_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
451 n_valid_batches = valid_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
452 n_test_batches = test_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
453
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
454 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
455 target = self.y
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
456
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
457 train_index = index % n_train_batches
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
458
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
459 classifier = self.logistic_regressor
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
460 cost = classifier.negative_log_likelihood(target)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
461 # compute the gradients with respect to the model parameters
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
462 gparams = T.grad(cost, self.params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
463
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
464 # compute list of fine-tuning updates
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
465 updates = [(param, param - gparam*finetune_lr)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
466 for param,gparam in zip(self.params, gparams)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
467
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
468 train_fn = theano.function([index], cost,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
469 updates = updates,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
470 givens = {
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
471 self.x : train_set_x[train_index*batch_size:(train_index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
472 target : train_set_y[train_index*batch_size:(train_index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
473
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
474 test_score_i = theano.function([index], classifier.errors(target),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
475 givens = {
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
476 self.x: test_set_x[index*batch_size:(index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
477 target: test_set_y[index*batch_size:(index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
478
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
479 valid_score_i = theano.function([index], classifier.errors(target),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
480 givens = {
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
481 self.x: valid_set_x[index*batch_size:(index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
482 target: valid_set_y[index*batch_size:(index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
483
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
484 def test_scores():
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
485 return [test_score_i(i) for i in xrange(n_test_batches)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
486
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
487 def valid_scores():
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
488 return [valid_score_i(i) for i in xrange(n_valid_batches)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
489
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
490 return train_fn, valid_scores, test_scores
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
491
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
492 def load_mnist(filename):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
493 f = gzip.open(filename,'rb')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
494 train_set, valid_set, test_set = cPickle.load(f)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
495 f.close()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
496
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
497 def shared_dataset(data_xy):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
498 data_x, data_y = data_xy
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
499 shared_x = theano.shared(numpy.asarray(data_x, dtype=theano.config.floatX))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
500 shared_y = theano.shared(numpy.asarray(data_y, dtype=theano.config.floatX))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
501 return shared_x, T.cast(shared_y, 'int32')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
502
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
503 n_train_examples = train_set[0].shape[0]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
504 datasets = shared_dataset(train_set), shared_dataset(valid_set), shared_dataset(test_set)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
505
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
506 return n_train_examples, datasets
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
507
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
508 def dbn_main(finetune_lr = 0.01,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
509 pretraining_epochs = 10,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
510 pretrain_lr = 0.1,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
511 training_epochs = 1000,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
512 batch_size = 20,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
513 mnist_file='mnist.pkl.gz'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
514 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
515 Demonstrate stochastic gradient descent optimization for a multilayer perceptron
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
516
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
517 This is demonstrated on MNIST.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
518
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
519 :param learning_rate: learning rate used in the finetune stage
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
520 (factor for the stochastic gradient)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
521
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
522 :param pretraining_epochs: number of epoch to do pretraining
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
523
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
524 :param pretrain_lr: learning rate to be used during pre-training
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
525
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
526 :param n_iter: maximal number of iterations ot run the optimizer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
527
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
528 :param mnist_file: path the the pickled mnist_file
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
529
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
530 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
531
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
532 n_train_examples, train_valid_test = load_mnist(mnist_file)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
533
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
534 print "Creating a Deep Belief Network"
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
535 deep_model = DBN(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
536 input_len=28*28,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
537 hidden_layers_sizes = [500, 150, 100],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
538 n_classes=10,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
539 rng = numpy.random.RandomState())
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
540
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
541 ####
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
542 #### Phase 1: Pre-training
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
543 ####
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
544 print "Pretraining (unsupervised learning) ..."
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
545
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
546 pretrain_functions = deep_model.pretraining_functions(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
547 batch_size=batch_size,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
548 train_set_x=train_valid_test[0][0],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
549 learning_rate=pretrain_lr,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
550 )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
551
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
552 start_time = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
553 for layer_idx, pretrain_fn in enumerate(pretrain_functions):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
554 # go through pretraining epochs
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
555 print 'Pre-training layer %i'% layer_idx
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
556 for i in xrange(pretraining_epochs * n_train_examples / batch_size):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
557 outstuff = pretrain_fn(i)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
558 xe, negsample = outstuff[:2]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
559 print (layer_idx, i,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
560 n_train_examples / batch_size,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
561 float(xe),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
562 'Wmin', deep_model.rbm_layers[0].W.value.min(),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
563 'Wmax', deep_model.rbm_layers[0].W.value.max(),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
564 'vmin', deep_model.rbm_layers[0].vbias.value.min(),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
565 'vmax', deep_model.rbm_layers[0].vbias.value.max(),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
566 #'x>0.3', (input_i>0.3).sum(),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
567 )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
568 sys.stdout.flush()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
569 if i % 1000 == 0:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
570 PIL.Image.fromarray(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
571 pylearn.io.image_tiling.tile_raster_images(negsample, (28,28), (10,10),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
572 tile_spacing=(1,1))).save('samples_%i_%i.png'%(layer_idx,i))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
573
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
574 PIL.Image.fromarray(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
575 pylearn.io.image_tiling.tile_raster_images(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
576 deep_model.rbm_layers[0].W.value.T,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
577 (28,28), (10,10),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
578 tile_spacing=(1,1))).save('filters_%i_%i.png'%(layer_idx,i))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
579 end_time = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
580 print 'Pretraining took %f minutes' %((end_time - start_time)/60.)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
581
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
582 return
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
583
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
584 print "Fine tuning (supervised learning) ..."
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
585 train_fn, valid_scores, test_scores =\
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
586 deep_model.finetune_functions(train_valid_test[0][0],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
587 learning_rate=finetune_lr, # the learning rate
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
588 batch_size = batch_size) # number of examples to use at once
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
589
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
590 ####
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
591 #### Phase 2: Fine Tuning
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
592 ####
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
593
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
594 patience = 10000 # look as this many examples regardless
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
595 patience_increase = 2. # wait this much longer when a new best is
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
596 # found
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
597 improvement_threshold = 0.995 # a relative improvement of this much is
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
598 # considered significant
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
599 validation_frequency = min(n_train_examples, patience/2)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
600 # go through this many
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
601 # minibatche before checking the network
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
602 # on the validation set; in this case we
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
603 # check every epoch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
604
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
605 patience_max = n_train_examples * training_epochs
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
606
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
607 best_epoch = None
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
608 best_epoch_test_score = None
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
609 best_epoch_valid_score = float('inf')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
610 start_time = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
611
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
612 for i in xrange(patience_max):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
613 if i >= patience:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
614 break
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
615
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
616 cost_i = train_fn(i)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
617
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
618 if i % validation_frequency == 0:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
619 validation_i = numpy.mean([score for score in valid_scores()])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
620
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
621 # if we got the best validation score until now
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
622 if validation_i < best_epoch_valid_score:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
623
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
624 # improve patience if loss improvement is good enough
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
625 threshold_i = best_epoch_valid_score * improvement_threshold
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
626 if validation_i < threshold_i:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
627 patience = max(patience, i * patience_increase)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
628
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
629 # save best validation score and iteration number
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
630 best_epoch_valid_score = validation_i
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
631 best_epoch = i/validation_i
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
632 best_epoch_test_score = numpy.mean(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
633 [score for score in test_scores()])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
634
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
635 print('epoch %i, validation error %f %%, test error %f %%'%(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
636 i/validation_frequency, validation_i*100.,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
637 best_epoch_test_score*100.))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
638 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
639 print('epoch %i, validation error %f %%' % (
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
640 i/validation_frequency, validation_i*100.))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
641 end_time = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
642
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
643 print(('Optimization complete with best validation score of %f %%,'
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
644 'with test performance %f %%') %
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
645 (finetune_status['best_validation_loss']*100.,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
646 finetune_status['test_score']*100.))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
647 print ('The code ran for %f minutes' % ((finetune_status['duration'])/60.))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
648
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
649 def rbm_main():
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
650 rbm = RBM(n_visible=20, n_hidden=30,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
651 numpy_rng = numpy.random.RandomState(34))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
652
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
653 cd_updates = rbm.cd_updates(lr=0.25)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
654
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
655 print cd_updates
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
656
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
657 f = function([rbm.input], [],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
658 updates={rbm.W:cd_updates[rbm.W]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
659
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
660 theano.printing.debugprint(f.maker.env.outputs[0],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
661 file=sys.stdout)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
662
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
663
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
664 if __name__ == '__main__':
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
665 dbn_main()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
666 #rbm_main()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
667
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
668
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
669 if 0:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
670 class DAA(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
671 def __init__(self, n_visible= 784, n_hidden= 500, corruption_level = 0.1,\
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
672 input = None, shared_W = None, shared_b = None):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
673 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
674 Initialize the dA class by specifying the number of visible units (the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
675 dimension d of the input ), the number of hidden units ( the dimension
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
676 d' of the latent or hidden space ) and the corruption level. The
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
677 constructor also receives symbolic variables for the input, weights and
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
678 bias. Such a symbolic variables are useful when, for example the input is
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
679 the result of some computations, or when weights are shared between the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
680 dA and an MLP layer. When dealing with SdAs this always happens,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
681 the dA on layer 2 gets as input the output of the dA on layer 1,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
682 and the weights of the dA are used in the second stage of training
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
683 to construct an MLP.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
684
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
685 :param n_visible: number of visible units
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
686
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
687 :param n_hidden: number of hidden units
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
688
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
689 :param input: a symbolic description of the input or None
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
690
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
691 :param corruption_level: the corruption mechanism picks up randomly this
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
692 fraction of entries of the input and turns them to 0
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
693
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
694
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
695 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
696 self.n_visible = n_visible
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
697 self.n_hidden = n_hidden
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
698
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
699 # create a Theano random generator that gives symbolic random values
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
700 theano_rng = RandomStreams()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
701
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
702 if shared_W != None and shared_b != None :
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
703 self.W = shared_W
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
704 self.b = shared_b
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
705 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
706 # initial values for weights and biases
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
707 # note : W' was written as `W_prime` and b' as `b_prime`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
708
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
709 # W is initialized with `initial_W` which is uniformely sampled
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
710 # from -6./sqrt(n_visible+n_hidden) and 6./sqrt(n_hidden+n_visible)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
711 # the output of uniform if converted using asarray to dtype
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
712 # theano.config.floatX so that the code is runable on GPU
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
713 initial_W = numpy.asarray( numpy.random.uniform( \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
714 low = -numpy.sqrt(6./(n_hidden+n_visible)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
715 high = numpy.sqrt(6./(n_hidden+n_visible)), \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
716 size = (n_visible, n_hidden)), dtype = theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
717 initial_b = numpy.zeros(n_hidden, dtype = theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
718
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
719
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
720 # theano shared variables for weights and biases
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
721 self.W = theano.shared(value = initial_W, name = "W")
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
722 self.b = theano.shared(value = initial_b, name = "b")
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
723
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
724
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
725 initial_b_prime= numpy.zeros(n_visible)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
726 # tied weights, therefore W_prime is W transpose
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
727 self.W_prime = self.W.T
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
728 self.b_prime = theano.shared(value = initial_b_prime, name = "b'")
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
729
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
730 # if no input is given, generate a variable representing the input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
731 if input == None :
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
732 # we use a matrix because we expect a minibatch of several examples,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
733 # each example being a row
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
734 self.x = T.matrix(name = 'input')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
735 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
736 self.x = input
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
737 # Equation (1)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
738 # keep 90% of the inputs the same and zero-out randomly selected subset of 10% of the inputs
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
739 # note : first argument of theano.rng.binomial is the shape(size) of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
740 # random numbers that it should produce
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
741 # second argument is the number of trials
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
742 # third argument is the probability of success of any trial
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
743 #
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
744 # this will produce an array of 0s and 1s where 1 has a
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
745 # probability of 1 - ``corruption_level`` and 0 with
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
746 # ``corruption_level``
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
747 self.tilde_x = theano_rng.binomial( self.x.shape, 1, 1 - corruption_level) * self.x
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
748 # Equation (2)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
749 # note : y is stored as an attribute of the class so that it can be
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
750 # used later when stacking dAs.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
751 self.y = T.nnet.sigmoid(T.dot(self.tilde_x, self.W ) + self.b)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
752 # Equation (3)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
753 self.z = T.nnet.sigmoid(T.dot(self.y, self.W_prime) + self.b_prime)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
754 # Equation (4)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
755 # note : we sum over the size of a datapoint; if we are using minibatches,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
756 # L will be a vector, with one entry per example in minibatch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
757 self.L = - T.sum( self.x*T.log(self.z) + (1-self.x)*T.log(1-self.z), axis=1 )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
758 # note : L is now a vector, where each element is the cross-entropy cost
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
759 # of the reconstruction of the corresponding example of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
760 # minibatch. We need to compute the average of all these to get
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
761 # the cost of the minibatch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
762 self.cost = T.mean(self.L)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
763
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
764 self.params = [ self.W, self.b, self.b_prime ]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
765
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
766 class StackedDAA(DeepLayerwiseModel):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
767 """Stacked denoising auto-encoder class (SdA)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
768
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
769 A stacked denoising autoencoder model is obtained by stacking several
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
770 dAs. The hidden layer of the dA at layer `i` becomes the input of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
771 the dA at layer `i+1`. The first layer dA gets as input the input of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
772 the SdA, and the hidden layer of the last dA represents the output.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
773 Note that after pretraining, the SdA is dealt with as a normal MLP,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
774 the dAs are only used to initialize the weights.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
775 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
776
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
777 def __init__(self, n_ins, hidden_layers_sizes, n_outs,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
778 corruption_levels, rng, ):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
779 """ This class is made to support a variable number of layers.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
780
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
781 :param train_set_x: symbolic variable pointing to the training dataset
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
782
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
783 :param train_set_y: symbolic variable pointing to the labels of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
784 training dataset
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
785
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
786 :param n_ins: dimension of the input to the sdA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
787
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
788 :param n_layers_sizes: intermidiate layers size, must contain
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
789 at least one value
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
790
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
791 :param n_outs: dimension of the output of the network
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
792
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
793 :param corruption_levels: amount of corruption to use for each
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
794 layer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
795
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
796 :param rng: numpy random number generator used to draw initial weights
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
797
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
798 :param pretrain_lr: learning rate used during pre-trainnig stage
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
799
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
800 :param finetune_lr: learning rate used during finetune stage
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
801 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
802
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
803 self.sigmoid_layers = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
804 self.daa_layers = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
805 self.pretrain_functions = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
806 self.params = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
807 self.n_layers = len(hidden_layers_sizes)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
808
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
809 if len(hidden_layers_sizes) < 1 :
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
810 raiseException (' You must have at least one hidden layer ')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
811
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
812 theano_rng = RandomStreams(rng.randint(2**30))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
813
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
814 # allocate symbolic variables for the data
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
815 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
816 self.x = T.matrix('x') # the data is presented as rasterized images
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
817 self.y = T.ivector('y') # the labels are presented as 1D vector of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
818 # [int] labels
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
819
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
820 # The SdA is an MLP, for which all weights of intermidiate layers
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
821 # are shared with a different denoising autoencoders
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
822 # We will first construct the SdA as a deep multilayer perceptron,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
823 # and when constructing each sigmoidal layer we also construct a
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
824 # denoising autoencoder that shares weights with that layer, and
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
825 # compile a training function for that denoising autoencoder
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
826
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
827 for i in xrange( self.n_layers ):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
828 # construct the sigmoidal layer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
829
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
830 sigmoid_layer = SigmoidalLayer(rng,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
831 self.layers[-1].output if i else self.x,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
832 hidden_layers_sizes[i-1] if i else n_ins,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
833 hidden_layers_sizes[i])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
834
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
835 daa_layer = DAA(corruption_level = corruption_levels[i],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
836 input = sigmoid_layer.input,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
837 W = sigmoid_layer.W,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
838 b = sigmoid_layer.b)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
839
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
840 # add the layer to the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
841 self.sigmoid_layers.append(sigmoid_layer)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
842 self.daa_layers.append(daa_layer)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
843
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
844 # its arguably a philosophical question...
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
845 # but we are going to only declare that the parameters of the
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
846 # sigmoid_layers are parameters of the StackedDAA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
847 # the hidden-layer biases in the daa_layers are parameters of those
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
848 # daa_layers, but not the StackedDAA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
849 self.params.extend(sigmoid_layer.params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
850
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
851 # We now need to add a logistic layer on top of the MLP
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
852 self.logistic_regressor = LogisticRegression(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
853 input = self.sigmoid_layers[-1].output,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
854 n_in = hidden_layers_sizes[-1],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
855 n_out = n_outs)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
856
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
857 self.params.extend(self.logLayer.params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
858
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
859 def pretraining_functions(self, train_set_x, batch_size):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
860
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
861 # compiles update functions for each layer, and
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
862 # returns them as a list
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
863 #
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
864 # Construct a function that trains this dA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
865 # compute gradients of layer parameters
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
866 gparams = T.grad(dA_layer.cost, dA_layer.params)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
867 # compute the list of updates
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
868 updates = {}
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
869 for param, gparam in zip(dA_layer.params, gparams):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
870 updates[param] = param - gparam * pretrain_lr
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
871
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
872 # create a function that trains the dA
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
873 update_fn = theano.function([index], dA_layer.cost, \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
874 updates = updates,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
875 givens = {
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
876 self.x : train_set_x[index*batch_size:(index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
877 # collect this function into a list
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
878 self.pretrain_functions += [update_fn]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
879
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
880