annotate code_tutoriel/logistic_sgd.py @ 239:42005ec87747

Merged (manually) Sylvain's changes to use Arnaud's dataset code, except that I don't use the givens. I probably also take a different approach to limiting the dataset size while debugging.
author fsavard
date Mon, 15 Mar 2010 18:30:21 -0400
parents 4bc5eeec6394
children
rev   line source
0
fda5f787baa6 initial commit
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
1 r"""
2 This tutorial introduces logistic regression using Theano and stochastic
3 gradient descent.
4
5 Logistic regression is a probabilistic, linear classifier. It is parametrized
6 by a weight matrix :math:`W` and a bias vector :math:`b`. Classification is
7 done by projecting data points onto a set of hyperplanes, the distance to
8 which is used to determine a class membership probability.
9
10 Mathematically, this can be written as:
11
12 .. math::
13   P(Y=i|x, W,b) &= softmax_i(W x + b) \\
14                 &= \frac {e^{W_i x + b_i}} {\sum_j e^{W_j x + b_j}}
15
16
17 The model's prediction is then obtained by taking the argmax of
18 the vector whose i'th element is P(Y=i|x).
19
20 .. math::
21
22   y_{pred} = argmax_i P(Y=i|x,W,b)
23
24
25 This tutorial presents a stochastic gradient descent optimization method
26 suitable for large datasets, and a conjugate gradient optimization method
27 that is suitable for smaller datasets.
28
29
30 References:
31
32     - textbooks: "Pattern Recognition and Machine Learning" -
33                  Christopher M. Bishop, section 4.3.2
34
35 """
36 __docformat__ = 'restructuredtext en'
37
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
38 import numpy, time, cPickle, gzip
0
39
40 import theano
41 import theano.tensor as T
42
43
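As a rough illustration of the formulas in the module docstring, the sketch below computes the softmax probabilities and the argmax prediction for a single input in plain Python, without Theano. The helper names `softmax` and `predict` and the toy weights are assumptions made for this example only; the weight layout follows the (n_in, n_out) shape used by the class that follows.

```python
import math


def softmax(scores):
    """Turn a list of class scores into probabilities that sum to 1."""
    # Subtracting the maximum does not change the result but avoids overflow.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict(x, W, b):
    """Return (probabilities, predicted class) for one input vector x.

    W is a nested list of shape (n_in, n_out); b has length n_out.
    """
    n_out = len(b)
    scores = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j]
              for j in range(n_out)]
    probs = softmax(scores)
    y_pred = max(range(n_out), key=lambda j: probs[j])
    return probs, y_pred


# Hypothetical toy model: 2 input features, 3 classes.
W = [[1.0, 0.0, -1.0],
     [0.0, 1.0, 0.0]]
b = [0.0, 0.0, 0.0]
probs, y_pred = predict([2.0, 0.5], W, b)
```

With all weights and biases set to zero, as in the initialisation used by the class, every class receives the same probability 1/n_out, so the first predictions carry no information until training updates the parameters.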
44 class LogisticRegression(object):
45     """Multi-class Logistic Regression Class
46
47     The logistic regression is fully described by a weight matrix :math:`W`
48     and bias vector :math:`b`. Classification is done by projecting data
49     points onto a set of hyperplanes, the distance to which is used to
50     determine a class membership probability.
51     """
52
53
54
55
56     def __init__(self, input, n_in, n_out):
57         """ Initialize the parameters of the logistic regression
58
165
59         :type input: theano.tensor.TensorType
0
60         :param input: symbolic variable that describes the input of the
165
61                       architecture (one minibatch)
62
63         :type n_in: int
0
64         :param n_in: number of input units, the dimension of the space in
165
65                      which the datapoints lie
0
66
165
67         :type n_out: int
0
68         :param n_out: number of output units, the dimension of the space in
165
69                       which the labels lie
0
70
71         """
72
73         # initialize with 0 the weights W as a matrix of shape (n_in, n_out)
165
74         self.W = theano.shared(value=numpy.zeros((n_in,n_out), dtype = theano.config.floatX),
75                                 name='W')
0
76         # initialize the biases b as a vector of n_out 0s
165
77         self.b = theano.shared(value=numpy.zeros((n_out,), dtype = theano.config.floatX),
78                                 name='b')
0
79
80
81         # compute vector of class-membership probabilities in symbolic form
82         self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W)+self.b)
83
84         # compute prediction as class whose probability is maximal in
85         # symbolic form
86         self.y_pred=T.argmax(self.p_y_given_x, axis=1)
87
165
88         # parameters of the model
89         self.params = [self.W, self.b]
90
0
91
92
93
94
95     def negative_log_likelihood(self, y):
96         r"""Return the mean of the negative log-likelihood of the prediction
97         of this model under a given target distribution.
98
99         .. math::
100
101             \frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D}) =
102             \frac{1}{|\mathcal{D}|} \sum_{i=0}^{|\mathcal{D}|-1} \log(P(Y=y^{(i)}|x^{(i)}, W,b)) \\
103             \ell (\theta=\{W,b\}, \mathcal{D}) = - \frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D})
104
165
105         :type y: theano.tensor.TensorType
0
106         :param y: corresponds to a vector that gives for each example the
165
107                   correct label
0
108
109         Note: we use the mean instead of the sum so that
165
110               the learning rate is less dependent on the batch size
0
111         """
165
112         # y.shape[0] is (symbolically) the number of rows in y, i.e., number of examples (call it n) in the minibatch
113         # T.arange(y.shape[0]) is a symbolic vector which will contain [0,1,2,... n-1]
114         # T.log(self.p_y_given_x) is a matrix of Log-Probabilities (call it LP) with one row per example and one column per class
115         # LP[T.arange(y.shape[0]),y] is a vector v containing [LP[0,y[0]], LP[1,y[1]], LP[2,y[2]], ..., LP[n-1,y[n-1]]]
116         # and T.mean(LP[T.arange(y.shape[0]),y]) is the mean (across minibatch examples) of the elements in v,
117         # i.e., the mean log-likelihood across the minibatch.
0
118         return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]),y])
119
120
121     def errors(self, y):
122         """Return the zero-one loss on the minibatch: the number of
123         misclassified examples divided by the total number of examples
124         in the minibatch
165
125
126         :type y: theano.tensor.TensorType
127         :param y: corresponds to a vector that gives for each example the
128                   correct label
0
129         """
130
131         # check that y has the same number of dimensions as y_pred
132         if y.ndim != self.y_pred.ndim:
133             raise TypeError('y should have the same shape as self.y_pred',
134                 ('y', y.type, 'y_pred', self.y_pred.type))
135         # check if y is of the correct datatype
136         if y.dtype.startswith('int'):
137             # the T.neq operator returns a vector of 0s and 1s, where 1
138             # represents a mistake in prediction
139             return T.mean(T.neq(self.y_pred, y))
140         else:
141             raise NotImplementedError()
142
143
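To make the two quantities computed by the class above more concrete, here is a minimal plain-Python sketch of the mean negative log-likelihood and the zero-one error on a tiny minibatch. The function names and the probability table are hypothetical and chosen only for illustration; the list comprehension over `p_y_given_x[i][y[i]]` stands in for the `LP[T.arange(n), y]` indexing described in the comments.

```python
import math


def mean_negative_log_likelihood(p_y_given_x, y):
    """Mean over examples of -log P(Y = y[i] | x[i])."""
    # For each example i, pick the probability assigned to its true label.
    picked = [math.log(p_y_given_x[i][y[i]]) for i in range(len(y))]
    return -sum(picked) / len(picked)


def zero_one_error(y_pred, y):
    """Fraction of examples whose predicted label differs from the target."""
    mistakes = [1 if p != t else 0 for p, t in zip(y_pred, y)]
    return sum(mistakes) / float(len(mistakes))


# Hypothetical minibatch: 3 examples, 2 classes, all true labels equal to 0.
p = [[0.5, 0.5],
     [0.25, 0.75],
     [1.0, 0.0]]
y = [0, 0, 0]

nll = mean_negative_log_likelihood(p, y)
y_pred = [max(range(len(row)), key=lambda j: row[j]) for row in p]
err = zero_one_error(y_pred, y)
```

Here the second example is misclassified, so the error is 1/3, while the mean negative log-likelihood also penalises the first example for being only 50% confident. This is one reason the log-likelihood, unlike the zero-one loss, gives a useful gradient for training.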
165
144 def load_data(dataset):
145     ''' Loads the dataset
146
147     :type dataset: string
148     :param dataset: the path to the dataset (here MNIST)
149     '''
150
151     #############
152     # LOAD DATA #
153     #############
154     print '... loading data'
155
156     # Load the dataset
157     f = gzip.open(dataset,'rb')
158     train_set, valid_set, test_set = cPickle.load(f)
159     f.close()
160
161
162     def shared_dataset(data_xy):
163         """ Function that loads the dataset into shared variables
164
165         The reason we store our dataset in shared variables is to allow
166         Theano to copy it into the GPU memory (when code is run on GPU).
167         Since copying data into the GPU is slow, copying a minibatch every time
168         it is needed (the default behaviour if the data is not in a shared
169         variable) would lead to a large decrease in performance.
170         """
171         data_x, data_y = data_xy
172         shared_x = theano.shared(numpy.asarray(data_x, dtype=theano.config.floatX))
173         shared_y = theano.shared(numpy.asarray(data_y, dtype=theano.config.floatX))
174         # When storing data on the GPU it has to be stored as floats
175         # therefore we will store the labels as ``floatX`` as well
176         # (``shared_y`` does exactly that). But during our computations
177         # we need them as ints (we use labels as index, and if they are
178         # floats it doesn't make sense) therefore instead of returning
179         # ``shared_y`` we will have to cast it to int. This little hack
180         # lets us get around this issue
181         return shared_x, T.cast(shared_y, 'int32')
182
183     test_set_x, test_set_y = shared_dataset(test_set)
184     valid_set_x, valid_set_y = shared_dataset(valid_set)
185     train_set_x, train_set_y = shared_dataset(train_set)
186
187     rval = [(train_set_x, train_set_y), (valid_set_x,valid_set_y), (test_set_x, test_set_y)]
188     return rval
0
189
190
191
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
192
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
193 def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000, dataset='mnist.pkl.gz'):
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
194 """
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
195 Demonstrate stochastic gradient descent optimization of a log-linear
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
196 model
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
197
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
198 This is demonstrated on MNIST.
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
199
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
200 :type learning_rate: float
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
201 :param learning_rate: learning rate used (factor for the stochastic
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
202 gradient)
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
203
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
204 :type n_epochs: int
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
205 :param n_epochs: maximal number of epochs to run the optimizer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
206
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
207 :type dataset: string
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
208 :param dataset: the path of the MNIST dataset file from
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
209 http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
210
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
211 """
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
212 datasets = load_data(dataset)
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
213
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
214 train_set_x, train_set_y = datasets[0]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
215 valid_set_x, valid_set_y = datasets[1]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
216 test_set_x , test_set_y = datasets[2]
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
217
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
218 batch_size = 600 # size of the minibatch
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
219
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
220 # compute number of minibatches for training, validation and testing
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
221 n_train_batches = train_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
222 n_valid_batches = valid_set_x.value.shape[0] // batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
223 n_test_batches = test_set_x.value.shape[0] // batch_size
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
224
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
225
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
226 ######################
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
227 # BUILD ACTUAL MODEL #
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
228 ######################
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
229 print '... building the model'
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
230
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
231
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
232 # allocate symbolic variables for the data
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
233 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
234 x = T.matrix('x') # the data is presented as rasterized images
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
235 y = T.ivector('y') # the labels are presented as a 1D vector of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
236 # [int] labels
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
237
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
238 # construct the logistic regression class
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
239 # Each MNIST image has size 28*28
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
240 classifier = LogisticRegression(input=x, n_in=28*28, n_out=10)
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
241
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
242 # the cost we minimize during training is the negative log likelihood of
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
243 # the model in symbolic format
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
244 cost = classifier.negative_log_likelihood(y)
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
245
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
246 # compiling a Theano function that computes the mistakes that are made by
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
247 # the model on a minibatch
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
248 test_model = theano.function(inputs = [index],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
249 outputs = classifier.errors(y),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
250 givens={
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
251 x:test_set_x[index*batch_size:(index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
252 y:test_set_y[index*batch_size:(index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
253
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
254 validate_model = theano.function( inputs = [index],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
255 outputs = classifier.errors(y),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
256 givens={
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
257 x:valid_set_x[index*batch_size:(index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
258 y:valid_set_y[index*batch_size:(index+1)*batch_size]})
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
259
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
260 # compute the gradient of cost with respect to theta = (W,b)
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
261 g_W = T.grad(cost = cost, wrt = classifier.W)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
262 g_b = T.grad(cost = cost, wrt = classifier.b)
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
263
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
264 # specify how to update the parameters of the model as a dictionary
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
265 updates = {classifier.W: classifier.W - learning_rate*g_W,
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
266 classifier.b: classifier.b - learning_rate*g_b}
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
267
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
268 # compiling a Theano function `train_model` that returns the cost, but at
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
269 # the same time updates the parameters of the model based on the rules
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
270 # defined in `updates`
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
271 train_model = theano.function(inputs = [index],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
272 outputs = cost,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
273 updates = updates,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
274 givens={
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
275 x:train_set_x[index*batch_size:(index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
276 y:train_set_y[index*batch_size:(index+1)*batch_size]})
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
277
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
278 ###############
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
279 # TRAIN MODEL #
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
280 ###############
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
281 print '... training the model'
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
282 # early-stopping parameters
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
283 patience = 5000 # look at this many examples regardless
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
284 patience_increase = 2 # wait this much longer when a new best is
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
285 # found
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
286 improvement_threshold = 0.995 # a relative improvement of this much is
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
287 # considered significant
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
288 validation_frequency = min(n_train_batches, patience // 2)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
289 # go through this many
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
290 # minibatches before checking the network
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
291 # on the validation set; in this case we
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
292 # check every epoch
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
293
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
294 best_params = None
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
295 best_validation_loss = float('inf')
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
296 test_score = 0.
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
297 start_time = time.clock()
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
298
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
299 done_looping = False
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
300 epoch = 0
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
301 while (epoch < n_epochs) and (not done_looping):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
302 epoch = epoch + 1
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
303 for minibatch_index in xrange(n_train_batches):
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
304
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
305 minibatch_avg_cost = train_model(minibatch_index)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
306 # iteration number
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
307 iter = (epoch - 1) * n_train_batches + minibatch_index
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
308
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
309 if (iter+1) % validation_frequency == 0:
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
310 # compute zero-one loss on validation set
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
311 validation_losses = [validate_model(i) for i in xrange(n_valid_batches)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
312 this_validation_loss = numpy.mean(validation_losses)
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
313
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
314 print('epoch %i, minibatch %i/%i, validation error %f %%' % \
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
315 (epoch, minibatch_index+1,n_train_batches, \
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
316 this_validation_loss*100.))
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
317
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
318
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
319 # if we got the best validation score until now
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
320 if this_validation_loss < best_validation_loss:
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
321 # improve patience if loss improvement is good enough
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
322 if this_validation_loss < best_validation_loss * \
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
323 improvement_threshold :
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
324 patience = max(patience, iter * patience_increase)
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
325
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
326 best_validation_loss = this_validation_loss
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
327 # test it on the test set
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
328
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
329 test_losses = [test_model(i) for i in xrange(n_test_batches)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
330 test_score = numpy.mean(test_losses)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
331
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
332 print((' epoch %i, minibatch %i/%i, test error of best '
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
333 'model %f %%') % \
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
334 (epoch, minibatch_index+1, n_train_batches,test_score*100.))
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
335
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
336 if patience <= iter :
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
337 done_looping = True
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
338 break
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
339
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
340 end_time = time.clock()
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
341 print(('Optimization complete with best validation score of %f %%, '
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
342 'with test performance %f %%') %
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
343 (best_validation_loss * 100., test_score*100.))
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
344 print ('The code ran for %f minutes' % ((end_time-start_time)/60.))
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
345
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
346 if __name__ == '__main__':
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
347 sgd_optimization_mnist()
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
348
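
The early-stopping rule used in the training loop above can be tried out without Theano. The following is a minimal sketch, not part of the tutorial file: `early_stopping_run` is a hypothetical helper, it is written in Python 3 syntax (the listing itself is Python 2), and it assumes validation happens exactly once per epoch, on the last minibatch, with iterations counted from zero. It replays a precomputed list of validation losses and reports when the patience budget would stop training.

```python
def early_stopping_run(validation_losses, n_train_batches,
                       patience=5000, patience_increase=2,
                       improvement_threshold=0.995):
    """Replay the patience rule over per-epoch validation losses.

    Hypothetical helper for illustration only.
    Returns (best_loss, epochs_run).
    """
    best_loss = float('inf')
    epochs_run = 0
    for epoch, loss in enumerate(validation_losses, start=1):
        epochs_run = epoch
        # zero-based index of the last minibatch of this epoch
        iteration = epoch * n_train_batches - 1

        if loss < best_loss:
            # extend patience only when the improvement is significant
            if loss < best_loss * improvement_threshold:
                patience = max(patience, iteration * patience_increase)
            best_loss = loss

        # stop once the iteration count has used up the patience budget
        if patience <= iteration:
            break
    return best_loss, epochs_run


if __name__ == '__main__':
    losses = [0.5, 0.4, 0.3] + [0.3] * 7
    print(early_stopping_run(losses, n_train_batches=10, patience=20))
```

Each significant improvement pushes the stopping point out to at least `patience_increase` times the current iteration, so a run that keeps improving keeps training, while a plateau eventually exhausts the budget.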