annotate code_tutoriel/logistic_sgd.py @ 356:b0741ea3ff6f

Extension of the choice of the main class for the training batches
author Guillaume Sicard <guitch21@gmail.com>
date Wed, 21 Apr 2010 23:47:50 -0400
parents 4bc5eeec6394
children
rev   line source
0
fda5f787baa6 initial commit
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
"""
This tutorial introduces logistic regression using Theano and stochastic
gradient descent.

Logistic regression is a probabilistic, linear classifier. It is parametrized
by a weight matrix :math:`W` and a bias vector :math:`b`. Classification is
done by projecting data points onto a set of hyperplanes, the distance to
which is used to determine a class membership probability.

Mathematically, this can be written as:

.. math::
  P(Y=i|x, W,b) &= softmax_i(W x + b) \\
                &= \frac {e^{W_i x + b_i}} {\sum_j e^{W_j x + b_j}}


The model's prediction is then obtained by taking the argmax of
the vector whose i'th element is P(Y=i|x).

.. math::

  y_{pred} = argmax_i P(Y=i|x,W,b)


This tutorial presents a stochastic gradient descent optimization method
suitable for large datasets, and a conjugate gradient optimization method
that is suitable for smaller datasets.


References:

    - textbooks: "Pattern Recognition and Machine Learning" -
                 Christopher M. Bishop, section 4.3.2

"""
__docformat__ = 'restructuredtext en'

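# The softmax and argmax formulas in the module docstring can be made concrete
# with a small sketch in plain Python (no Theano). This is only an
# illustration, not part of the tutorial code: the names ``softmax`` and
# ``predict`` and the toy values of W, b and x are assumptions introduced
# here. W is stored as one row per class, matching the W_i notation of the
# docstring rather than the (n_in, n_out) layout used by the class in this file.

```python
import math


def softmax(scores):
    """Turn a list of real-valued scores into probabilities that sum to 1."""
    # Subtracting the maximum leaves the result unchanged mathematically
    # but avoids overflow in math.exp for large scores.
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict(W, b, x):
    """Return (probabilities, predicted class) for one input vector x.

    W is a list of per-class weight rows W_i, b a list of per-class biases.
    """
    # Score of class i is the dot product W_i . x plus the bias b_i.
    scores = [sum(w_ij * x_j for w_ij, x_j in zip(W_i, x)) + b_i
              for W_i, b_i in zip(W, b)]
    probs = softmax(scores)
    # argmax over classes: index of the largest probability.
    y_pred = max(range(len(probs)), key=lambda i: probs[i])
    return probs, y_pred


# Hypothetical toy model: 3 classes, 2 input features.
W = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b = [0.0, 0.0, 0.0]
probs, y_pred = predict(W, b, [2.0, 1.0])
```

# Here the class scores are 2, 1 and -3, so the first class (index 0) receives
# the largest probability and is returned as the prediction.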
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
import numpy, time, cPickle, gzip

import theano
import theano.tensor as T


class LogisticRegression(object):
    """Multi-class Logistic Regression Class

    The logistic regression is fully described by a weight matrix :math:`W`
    and bias vector :math:`b`. Classification is done by projecting data
    points onto a set of hyperplanes, the distance to which is used to
    determine a class membership probability.
    """

    def __init__(self, input, n_in, n_out):
        """ Initialize the parameters of the logistic regression

        :type input: theano.tensor.TensorType
        :param input: symbolic variable that describes the input of the
                      architecture (one minibatch)

        :type n_in: int
        :param n_in: number of input units, the dimension of the space in
                     which the datapoints lie

        :type n_out: int
        :param n_out: number of output units, the dimension of the space in
                      which the labels lie

        """

        # initialize the weights W with 0s, as a matrix of shape (n_in, n_out)
        self.W = theano.shared(value=numpy.zeros((n_in, n_out), dtype=theano.config.floatX),
                               name='W')
        # initialize the biases b as a vector of n_out 0s
        self.b = theano.shared(value=numpy.zeros((n_out,), dtype=theano.config.floatX),
                               name='b')

        # compute vector of class-membership probabilities in symbolic form
        self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)

        # compute prediction as class whose probability is maximal in
        # symbolic form
        self.y_pred = T.argmax(self.p_y_given_x, axis=1)

        # parameters of the model
        self.params = [self.W, self.b]

    def negative_log_likelihood(self, y):
        """Return the mean of the negative log-likelihood of the prediction
        of this model under a given target distribution.

        .. math::

            \frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D}) =
            \frac{1}{|\mathcal{D}|} \sum_{i=0}^{|\mathcal{D}|-1} \log(P(Y=y^{(i)}|x^{(i)}, W,b)) \\
            \ell (\theta=\{W,b\}, \mathcal{D}) =
            -\frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D})

        :type y: theano.tensor.TensorType
        :param y: corresponds to a vector that gives for each example the
                  correct label

        Note: we use the mean instead of the sum so that
              the learning rate is less dependent on the batch size
        """
        # y.shape[0] is (symbolically) the length of y, i.e., the number of
        # examples (call it n) in the minibatch.
        # T.arange(y.shape[0]) is a symbolic vector which will contain
        # [0,1,2,... n-1].
        # T.log(self.p_y_given_x) is a matrix of log-probabilities (call it LP)
        # with one row per example and one column per class.
        # LP[T.arange(y.shape[0]),y] is a vector v containing
        # [LP[0,y[0]], LP[1,y[1]], LP[2,y[2]], ..., LP[n-1,y[n-1]]],
        # and T.mean(LP[T.arange(y.shape[0]),y]) is the mean (across minibatch
        # examples) of the elements in v, i.e., the mean log-likelihood across
        # the minibatch.
        return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y])

    def errors(self, y):
        """Return a float representing the number of errors in the minibatch
        over the total number of examples of the minibatch; zero-one
        loss over the size of the minibatch

        :type y: theano.tensor.TensorType
        :param y: corresponds to a vector that gives for each example the
                  correct label
        """

        # check if y has the same dimension as y_pred
        if y.ndim != self.y_pred.ndim:
            raise TypeError('y should have the same shape as self.y_pred',
                            ('y', y.type, 'y_pred', self.y_pred.type))
        # check if y is of the correct datatype
        if y.dtype.startswith('int'):
            # the T.neq operator returns a vector of 0s and 1s, where 1
            # represents a mistake in prediction
            return T.mean(T.neq(self.y_pred, y))
        else:
            raise NotImplementedError()


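# The two quantities computed by the class, the mean negative log-likelihood
# and the zero-one error rate, can be checked by hand on a tiny minibatch.
# The sketch below uses plain Python lists instead of Theano tensors; the
# function names and the toy probabilities are assumptions made for this
# illustration and are not part of the tutorial code.

```python
import math


def mean_negative_log_likelihood(p_y_given_x, y):
    """Mean over examples of -log P(Y = y[i] | x[i])."""
    # Pick, for each row i, the log-probability of the correct class y[i];
    # this mirrors LP[arange(n), y] in the symbolic version.
    picked = [math.log(row[label]) for row, label in zip(p_y_given_x, y)]
    return -sum(picked) / float(len(picked))


def error_rate(y_pred, y):
    """Fraction of examples whose predicted label differs from the target."""
    mistakes = [1 if p != t else 0 for p, t in zip(y_pred, y)]
    return sum(mistakes) / float(len(mistakes))


# Hypothetical minibatch of 3 examples over 2 classes.
p_y_given_x = [[0.5, 0.5], [0.25, 0.75], [0.5, 0.5]]
y = [0, 1, 1]          # correct labels
y_pred = [0, 1, 0]     # predicted labels, the last one is wrong
nll = mean_negative_log_likelihood(p_y_given_x, y)
err = error_rate(y_pred, y)
```

# Dividing by the number of examples is what keeps both values on the same
# scale whatever the minibatch size, which is the reason the docstring gives
# for using the mean rather than the sum.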
def load_data(dataset):
    ''' Loads the dataset

    :type dataset: string
    :param dataset: the path to the dataset (here MNIST)
    '''

    #############
    # LOAD DATA #
    #############
    print '... loading data'

    # Load the dataset
    f = gzip.open(dataset, 'rb')
    train_set, valid_set, test_set = cPickle.load(f)
    f.close()

    def shared_dataset(data_xy):
        """ Function that loads the dataset into shared variables

        The reason we store our dataset in shared variables is to allow
        Theano to copy it into the GPU memory (when code is run on GPU).
        Since copying data into the GPU is slow, copying a minibatch every
        time it is needed (the default behaviour if the data is not in a
        shared variable) would lead to a large decrease in performance.
        """
        data_x, data_y = data_xy
        shared_x = theano.shared(numpy.asarray(data_x, dtype=theano.config.floatX))
        shared_y = theano.shared(numpy.asarray(data_y, dtype=theano.config.floatX))
        # When storing data on the GPU it has to be stored as floats,
        # therefore we will store the labels as ``floatX`` as well
        # (``shared_y`` does exactly that). But during our computations
        # we need them as ints (we use labels as indices, and if they are
        # floats that doesn't make sense), therefore instead of returning
        # ``shared_y`` we will have to cast it to int. This little hack
        # lets us get around this issue.
        return shared_x, T.cast(shared_y, 'int32')

    test_set_x, test_set_y = shared_dataset(test_set)
    valid_set_x, valid_set_y = shared_dataset(valid_set)
    train_set_x, train_set_y = shared_dataset(train_set)

    rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y)]
    return rval


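# The comment in shared_dataset explains that labels are stored as floats and
# cast back to integers before use, because they serve as indices. The plain
# Python sketch below shows why the cast is needed; it does not use Theano or
# a GPU, and the helper names and toy data are assumptions made for this
# illustration only.

```python
def store_as_floats(labels):
    """Mimic the floatX storage step: every label becomes a float."""
    return [float(label) for label in labels]


def cast_to_ints(stored_labels):
    """Mimic the cast back to integers so labels can serve as indices."""
    return [int(value) for value in stored_labels]


# Hypothetical per-example class probabilities and integer labels.
probabilities = [[0.9, 0.1], [0.2, 0.8]]
labels = [0, 1]

stored = store_as_floats(labels)

# Indexing with a float label is rejected by Python lists.
try:
    probabilities[0][stored[0]]
    float_index_works = True
except TypeError:
    float_index_works = False

# After the cast, each label selects the probability of its own class.
usable = cast_to_ints(stored)
picked = [row[k] for row, k in zip(probabilities, usable)]
```

# The same idea applies to the symbolic version: the stored values stay
# floats, and only the view handed to the model is cast to an integer type.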
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
193 def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000, dataset='mnist.pkl.gz'):
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
194 """
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
195 Demonstrate stochastic gradient descent optimization of a log-linear
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
196 model
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
197
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
198 This is demonstrated on MNIST.
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
199
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
200 :type learning_rate: float
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
201 :param learning_rate: learning rate used (factor for the stochastic
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
202 gradient)
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
203
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
204 :type n_epochs: int
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
205 :param n_epochs: maximal number of epochs to run the optimizer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
206
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
207 :type dataset: string
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
208 :param dataset: the path of the MNIST dataset file from
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
209 http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
210
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
211 """
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
212 datasets = load_data(dataset)
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
213
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
214 train_set_x, train_set_y = datasets[0]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
215 valid_set_x, valid_set_y = datasets[1]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
216 test_set_x , test_set_y = datasets[2]
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
217
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
218 batch_size = 600 # size of the minibatch
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
219
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
220 # compute number of minibatches for training, validation and testing
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
221 n_train_batches = train_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
222 n_valid_batches = valid_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
223 n_test_batches = test_set_x.value.shape[0] / batch_size
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
224
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
225
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
226 ######################
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
227 # BUILD ACTUAL MODEL #
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
228 ######################
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
229 print '... building the model'
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
230
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
231
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
232 # allocate symbolic variables for the data
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
233 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
234 x = T.matrix('x') # the data is presented as rasterized images
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
235 y = T.ivector('y') # the labels are presented as a 1D vector of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
236 # [int] labels
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
237
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
238 # construct the logistic regression class
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
239 # Each MNIST image has size 28*28
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
240 classifier = LogisticRegression(input=x, n_in=28*28, n_out=10)
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
241
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
242 # the cost we minimize during training is the negative log likelihood of
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
243 # the model in symbolic format
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
244 cost = classifier.negative_log_likelihood(y)
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
245
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
246 # compiling a Theano function that computes the mistakes that are made by
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
247 # the model on a minibatch
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
248 test_model = theano.function(inputs = [index],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
249 outputs = classifier.errors(y),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
250 givens={
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
251 x:test_set_x[index*batch_size:(index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
252 y:test_set_y[index*batch_size:(index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
253
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
254 validate_model = theano.function( inputs = [index],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
255 outputs = classifier.errors(y),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
256 givens={
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
257 x:valid_set_x[index*batch_size:(index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
258 y:valid_set_y[index*batch_size:(index+1)*batch_size]})
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
259
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
260 # compute the gradient of cost with respect to theta = (W,b)
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
261 g_W = T.grad(cost = cost, wrt = classifier.W)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
262 g_b = T.grad(cost = cost, wrt = classifier.b)
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
263
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
264 # specify how to update the parameters of the model as a dictionary
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
265 updates = {classifier.W: classifier.W - learning_rate*g_W,
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
266 classifier.b: classifier.b - learning_rate*g_b}
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
267
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
268 # compile a Theano function `train_model` that returns the cost and, at
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
269 # the same time, updates the parameters of the model based on the rules
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
270 # defined in `updates`
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
271 train_model = theano.function(inputs = [index],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
272 outputs = cost,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
273 updates = updates,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
274 givens={
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
275 x:train_set_x[index*batch_size:(index+1)*batch_size],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
276 y:train_set_y[index*batch_size:(index+1)*batch_size]})
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
277
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
278 ###############
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
279 # TRAIN MODEL #
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
280 ###############
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
281 print '... training the model'
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
282 # early-stopping parameters
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
283 patience = 5000 # look at this many minibatch iterations regardless
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
284 patience_increase = 2 # wait this much longer when a new best is
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
285 # found
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
286 improvement_threshold = 0.995 # a relative improvement of this much is
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
287 # considered significant
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
288 validation_frequency = min(n_train_batches, patience/2)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
289 # go through this many
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
290 # minibatches before checking the network
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
291 # on the validation set; in this case we
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
292 # check every epoch
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
293
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
294 best_params = None
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
295 best_validation_loss = float('inf')
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
296 test_score = 0.
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
297 start_time = time.clock()
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
298
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
299 done_looping = False
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
300 epoch = 0
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
301 while (epoch < n_epochs) and (not done_looping):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
302 epoch = epoch + 1
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
303 for minibatch_index in xrange(n_train_batches):
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
304
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
305 minibatch_avg_cost = train_model(minibatch_index)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
306 # iteration number
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
307 iter = (epoch - 1) * n_train_batches + minibatch_index
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
308
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
309 if (iter+1) % validation_frequency == 0:
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
310 # compute zero-one loss on validation set
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
311 validation_losses = [validate_model(i) for i in xrange(n_valid_batches)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
312 this_validation_loss = numpy.mean(validation_losses)
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
313
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
314 print('epoch %i, minibatch %i/%i, validation error %f %%' % \
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
315 (epoch, minibatch_index+1,n_train_batches, \
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
316 this_validation_loss*100.))
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
317
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
318
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
319 # if we got the best validation score until now
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
320 if this_validation_loss < best_validation_loss:
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
321 # increase patience if the loss improvement is good enough
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
322 if this_validation_loss < best_validation_loss * \
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
323 improvement_threshold :
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
324 patience = max(patience, iter * patience_increase)
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
325
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
326 best_validation_loss = this_validation_loss
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
327 # test it on the test set
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
328
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
329 test_losses = [test_model(i) for i in xrange(n_test_batches)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
330 test_score = numpy.mean(test_losses)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
331
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
332 print((' epoch %i, minibatch %i/%i, test error of best '
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
333 'model %f %%') % \
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
334 (epoch, minibatch_index+1, n_train_batches,test_score*100.))
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
335
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
336 if patience <= iter :
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents: 2
diff changeset
337 done_looping = True
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
338 break
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
339
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
340 end_time = time.clock()
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
341 print(('Optimization complete with best validation score of %f %%, '
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
342 'with test performance %f %%') %
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
343 (best_validation_loss * 100., test_score*100.))
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
344 print ('The code ran for %f minutes' % ((end_time-start_time)/60.))
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
345
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
346 if __name__ == '__main__':
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
347 sgd_optimization_mnist()
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
348
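The patience-based early stopping used in the training loop above can be studied on its own, without Theano. The following is a minimal sketch under stated assumptions, not the tutorial's code: the function name `run_early_stopping` and the idea of replaying a precomputed list of validation losses are hypothetical, introduced only for illustration. It is written for Python 3 (the listing above is Python 2) and, for simplicity, assumes one validation check per iteration rather than one every `validation_frequency` minibatches.

```python
def run_early_stopping(validation_losses, patience=5000,
                       patience_increase=2, improvement_threshold=0.995):
    """Replay validation losses (one per iteration) through the
    patience rule. Hypothetical helper, for illustration only.

    Returns (best_loss, best_iter, last_iter), where last_iter is the
    iteration at which the loop stopped.
    """
    best_loss = float('inf')
    best_iter = None
    last_iter = None
    for it, loss in enumerate(validation_losses):
        last_iter = it
        if loss < best_loss:
            # Extend patience only when the relative improvement is
            # significant, i.e. better than improvement_threshold.
            if loss < best_loss * improvement_threshold:
                patience = max(patience, it * patience_increase)
            best_loss = loss
            best_iter = it
        # Stop once the iteration count has used up the patience budget.
        if patience <= it:
            break
    return best_loss, best_iter, last_iter


if __name__ == '__main__':
    losses = [0.5, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]
    print(run_early_stopping(losses, patience=3))
```

With a small `patience`, the loop stops shortly after the loss plateaus; each significant improvement at iteration `it` pushes the stopping point out to at least `it * patience_increase`.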