annotate code_tutoriel/logistic_sgd.py @ 0:fda5f787baa6

commit initial
author Dumitru Erhan <dumitru.erhan@gmail.com>
date Thu, 21 Jan 2010 11:26:43 -0500
parents
children bcc87d3e33a3
rev   line source
r"""
This tutorial introduces logistic regression using Theano and stochastic
gradient descent.

Logistic regression is a probabilistic, linear classifier. It is parametrized
by a weight matrix :math:`W` and a bias vector :math:`b`. Classification is
done by projecting data points onto a set of hyperplanes, the distance to
which is used to determine a class membership probability.

Mathematically, this can be written as:

.. math::

  P(Y=i|x, W,b) &= softmax_i(W x + b) \\
                &= \frac {e^{W_i x + b_i}} {\sum_j e^{W_j x + b_j}}


The model's prediction is then obtained by taking the argmax of the vector
whose i'th element is :math:`P(Y=i|x)`.

.. math::

  y_{pred} = argmax_i P(Y=i|x,W,b)


This tutorial presents a stochastic gradient descent optimization method
suitable for large datasets, and a conjugate gradient optimization method
that is suitable for smaller datasets.


References:

    - textbooks: "Pattern Recognition and Machine Learning" -
                 Christopher M. Bishop, section 4.3.2

"""
__docformat__ = 'restructuredtext en'


import cPickle
import gzip
import time

import numpy

import theano
import theano.tensor as T
import theano.tensor.nnet

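# The block below is not part of the original tutorial code. It is a minimal
# NumPy sketch of the softmax and argmax formulas from the module docstring,
# assuming one example per row of `x` and weights of shape (n_in, n_out) as
# in the class that follows. The names `softmax_rows`, `predict` and the
# `*_demo` arrays are hypothetical and used for illustration only.

```python
import numpy as np


def softmax_rows(scores):
    """Row-wise softmax of a (n_examples, n_classes) score matrix."""
    # subtracting the row maximum leaves the result unchanged but avoids
    # overflow in exp for large scores
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


def predict(x, W, b):
    """Return (class probabilities, argmax prediction) for each row of x."""
    p_y_given_x = softmax_rows(np.dot(x, W) + b)
    return p_y_given_x, p_y_given_x.argmax(axis=1)


# hypothetical toy problem: 2 examples, 3 input features, 2 classes
x_demo = np.array([[1.0, 0.0, 2.0],
                   [0.0, 1.0, 0.0]])
W_demo = np.array([[1.0, -1.0],
                   [-2.0, 2.0],
                   [0.5, 0.0]])
b_demo = np.zeros(2)
probs_demo, y_pred_demo = predict(x_demo, W_demo, b_demo)
```

# Each row of `probs_demo` is a probability distribution over the classes,
# and `y_pred_demo` holds the index of the most probable class per example.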
class LogisticRegression(object):
    """Multi-class Logistic Regression Class

    The logistic regression is fully described by a weight matrix :math:`W`
    and bias vector :math:`b`. Classification is done by projecting data
    points onto a set of hyperplanes, the distance to which is used to
    determine a class membership probability.
    """

    def __init__(self, input, n_in, n_out):
        """ Initialize the parameters of the logistic regression

        :param input: symbolic variable that describes the input of the
                      architecture (one minibatch)

        :param n_in: number of input units, the dimension of the space in
                     which the datapoints lie

        :param n_out: number of output units, the dimension of the space in
                      which the labels lie

        """

        # initialize with 0 the weights W as a matrix of shape (n_in, n_out)
        self.W = theano.shared(value=numpy.zeros((n_in, n_out),
                                                 dtype=theano.config.floatX))
        # initialize the biases b as a vector of n_out 0s
        self.b = theano.shared(value=numpy.zeros((n_out,),
                                                 dtype=theano.config.floatX))

        # compute vector of class-membership probabilities in symbolic form
        self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)

        # compute prediction as class whose probability is maximal in
        # symbolic form
        self.y_pred = T.argmax(self.p_y_given_x, axis=1)

    def negative_log_likelihood(self, y):
        r"""Return the mean of the negative log-likelihood of the prediction
        of this model under a given target distribution.

        .. math::

            \frac{1}{|\mathcal{D}|} \ell (\theta=\{W,b\}, \mathcal{D}) =
            -\frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D}) =
            -\frac{1}{|\mathcal{D}|} \sum_{i=0}^{|\mathcal{D}|-1}
                \log(P(Y=y^{(i)}|x^{(i)}, W,b))

        :param y: corresponds to a vector that gives for each example the
                  correct label

        Note: we use the mean instead of the sum so that
        the learning rate is less dependent on the batch size
        """
        # T.log(self.p_y_given_x) holds one row of log-probabilities per
        # example; indexing with (arange(n), y) picks, for each example, the
        # log-probability assigned to its correct label
        return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y])

    def errors(self, y):
        """Return a float representing the number of errors in the minibatch
        divided by the total number of examples in the minibatch, i.e. the
        zero-one loss averaged over the minibatch
        """

        # check that y has the same number of dimensions as y_pred
        if y.ndim != self.y_pred.ndim:
            raise TypeError('y should have the same shape as self.y_pred',
                            ('y', y.type, 'y_pred', self.y_pred.type))
        # check if y is of the correct datatype
        if y.dtype.startswith('int'):
            # the T.neq operator returns a vector of 0s and 1s, where 1
            # represents a mistake in prediction
            return T.mean(T.neq(self.y_pred, y))
        else:
            raise NotImplementedError()
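
# The block below is not part of the original tutorial code. It is a minimal
# NumPy sketch of what `negative_log_likelihood` and `errors` compute on
# concrete arrays, assuming `p_y_given_x` already holds one row of class
# probabilities per example. The function names and the `*_demo` values are
# hypothetical and used for illustration only.

```python
import numpy as np


def mean_negative_log_likelihood(p_y_given_x, y):
    """Mean over the minibatch of -log P(Y = y[i] | x[i])."""
    # pick, for each example i, the probability assigned to its label y[i]
    picked = p_y_given_x[np.arange(y.shape[0]), y]
    return -np.mean(np.log(picked))


def error_rate(y_pred, y):
    """Fraction of examples whose predicted label differs from the target."""
    return np.mean(y_pred != y)


# hypothetical minibatch: 3 examples, 2 classes
p_demo = np.array([[0.5, 0.5],
                   [0.25, 0.75],
                   [0.9, 0.1]])
y_demo = np.array([0, 1, 1])
nll_demo = mean_negative_log_likelihood(p_demo, y_demo)
err_demo = error_rate(p_demo.argmax(axis=1), y_demo)
```

# Only the third example is misclassified here, so `err_demo` is one third.
# Because both quantities are means, their scale does not grow with the
# minibatch size.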


def sgd_optimization_mnist(learning_rate=0.01, n_iter=100):
    """
    Demonstrate stochastic gradient descent optimization of a log-linear
    model

    This is demonstrated on MNIST.

    :param learning_rate: learning rate used (factor for the stochastic
                          gradient)

    :param n_iter: maximum number of iterations (passes through the entire
                   dataset) to run the optimizer

    """

    # Load the dataset
    f = gzip.open('mnist.pkl.gz', 'rb')
    train_set, valid_set, test_set = cPickle.load(f)
    f.close()

    # make minibatches of size 20
    batch_size = 20    # size of the minibatch

    # Dealing with the training set
    # get the list of training images (x) and their labels (y)
    (train_set_x, train_set_y) = train_set
    # initialize the list of training minibatches with empty list
    train_batches = []
    for i in xrange(0, len(train_set_x), batch_size):
        # add to the list of minibatches the minibatch starting at
        # position i, ending at position i+batch_size
        # a minibatch is a pair ; the first element of the pair is a list
        # of datapoints, the second element is the list of corresponding
        # labels
        train_batches.append(
            (train_set_x[i:i+batch_size], train_set_y[i:i+batch_size]))

    # Dealing with the validation set
    (valid_set_x, valid_set_y) = valid_set
    # initialize the list of validation minibatches
    valid_batches = []
    for i in xrange(0, len(valid_set_x), batch_size):
        valid_batches.append(
            (valid_set_x[i:i+batch_size], valid_set_y[i:i+batch_size]))

    # Dealing with the testing set
    (test_set_x, test_set_y) = test_set
    # initialize the list of testing minibatches
    test_batches = []
    for i in xrange(0, len(test_set_x), batch_size):
        test_batches.append(
            (test_set_x[i:i+batch_size], test_set_y[i:i+batch_size]))

    ishape = (28, 28)  # this is the size of MNIST images

    # allocate symbolic variables for the data
    x = T.fmatrix()  # the data is presented as rasterized images
    y = T.lvector()  # the labels are presented as 1D vector of
                     # [long int] labels

    # construct the logistic regression class
    classifier = LogisticRegression(
        input=x.reshape((batch_size, 28*28)), n_in=28*28, n_out=10)

    # the cost we minimize during training is the negative log likelihood of
    # the model in symbolic format
    cost = classifier.negative_log_likelihood(y)

    # compiling a Theano function that computes the mistakes that are made by
    # the model on a minibatch
    test_model = theano.function([x, y], classifier.errors(y))

    # compute the gradient of cost with respect to theta = (W,b)
    g_W = T.grad(cost, classifier.W)
    g_b = T.grad(cost, classifier.b)

    # specify how to update the parameters of the model as a dictionary
    updates = {classifier.W: classifier.W - learning_rate*g_W,
               classifier.b: classifier.b - learning_rate*g_b}

    # compiling a Theano function `train_model` that returns the cost, but at
    # the same time updates the parameters of the model based on the rules
    # defined in `updates`
    train_model = theano.function([x, y], cost, updates=updates)

    n_minibatches = len(train_batches)  # number of minibatches

    # early-stopping parameters
    patience = 5000  # look at this many examples regardless
    patience_increase = 2  # wait this much longer when a new best is
                           # found
    improvement_threshold = 0.995  # a relative improvement of this much is
                                   # considered significant
    validation_frequency = n_minibatches  # go through this many
                                          # minibatches before checking the
                                          # network on the validation set;
                                          # in this case we check every epoch

fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
238 best_params = None
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
239 best_validation_loss = float('inf')
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
240 test_score = 0.
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
241 start_time = time.clock()
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
242 # have a maximum of `n_iter` iterations through the entire dataset
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
243 for iter in xrange(n_iter* n_minibatches):
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
244
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
245 # get epoch and minibatch index
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
246 epoch = iter / n_minibatches
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
247 minibatch_index = iter % n_minibatches
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
248
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
249 # get the minibatches corresponding to `iter` modulo
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
250 # `len(train_batches)`
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
251 x,y = train_batches[ minibatch_index ]
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
252 cost_ij = train_model(x,y)
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
253
        if (iter + 1) % validation_frequency == 0:
            # compute zero-one loss on validation set
            this_validation_loss = 0.
            for x, y in valid_batches:
                # sum up the errors for each minibatch
                this_validation_loss += test_model(x, y)
            # get the average by dividing by the number of minibatches
            this_validation_loss /= len(valid_batches)

            print('epoch %i, minibatch %i/%i, validation error %f %%' %
                  (epoch, minibatch_index + 1, n_minibatches,
                   this_validation_loss * 100.))

            # if this is the best validation score so far
            if this_validation_loss < best_validation_loss:
                # increase patience if the loss improvement is good enough
                if this_validation_loss < best_validation_loss * \
                        improvement_threshold:
                    patience = max(patience, iter * patience_increase)

                best_validation_loss = this_validation_loss

                # evaluate the new best model on the test set
                test_score = 0.
                for x, y in test_batches:
                    test_score += test_model(x, y)
                test_score /= len(test_batches)
                print((' epoch %i, minibatch %i/%i, test error of best '
                       'model %f %%') %
                      (epoch, minibatch_index + 1, n_minibatches,
                       test_score * 100.))

        if patience <= iter:
            break

    end_time = time.clock()
    print(('Optimization complete with best validation score of %f %%, '
           'with test performance %f %%') %
          (best_validation_loss * 100., test_score * 100.))
    print('The code ran for %f minutes' % ((end_time - start_time) / 60.))


if __name__ == '__main__':
    sgd_optimization_mnist()
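The patience-based early stopping used in the training loop above can be looked at in isolation. The following is a minimal sketch, not part of the original file: it replays the same rule over a precomputed list of validation losses instead of training a Theano model. The function name `early_stopping_schedule` and its return value are hypothetical and chosen only for illustration, and it is written for Python 3 (`range`, `//`) rather than the Python 2 of the listing.

```python
def early_stopping_schedule(validation_losses, validation_frequency,
                            patience, patience_increase,
                            improvement_threshold):
    """Replay the patience rule over a list of validation losses.

    `validation_losses[k]` is assumed to be the loss measured at the
    (k+1)-th validation, i.e. at iteration
    `(k + 1) * validation_frequency - 1`.
    Returns (best_loss, best_iter, last_iter), where `last_iter` is the
    iteration at which the loop stopped.
    """
    best_loss = float('inf')
    best_iter = None
    n_iters = len(validation_losses) * validation_frequency

    for it in range(n_iters):
        if (it + 1) % validation_frequency == 0:
            loss = validation_losses[(it + 1) // validation_frequency - 1]
            if loss < best_loss:
                # a sufficiently large relative improvement extends patience
                if loss < best_loss * improvement_threshold:
                    patience = max(patience, it * patience_increase)
                best_loss = loss
                best_iter = it
        # stop once the iteration counter has caught up with patience
        if patience <= it:
            return best_loss, best_iter, it

    # ran out of iterations before patience was exhausted
    return best_loss, best_iter, n_iters - 1


if __name__ == '__main__':
    losses = [0.5, 0.4, 0.39, 0.39, 0.39, 0.39]
    print(early_stopping_schedule(losses, validation_frequency=10,
                                  patience=20, patience_increase=2,
                                  improvement_threshold=0.995))
```

With these illustrative numbers, each of the first three validations improves the loss by more than the threshold, so patience grows to 58 (twice the iteration index 29 of the last improvement). The following validations bring no improvement, and the loop stops at iteration 58.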