Mercurial > ift6266
annotate code_tutoriel/logistic_sgd.py @ 487:21787ac4e5a0
author:   Yoshua Bengio <bengioy@iro.umontreal.ca>
date:     Mon, 31 May 2010 22:04:44 -0400
parents:  4bc5eeec6394
children:
"""
This tutorial introduces logistic regression using Theano and stochastic
gradient descent.

Logistic regression is a probabilistic, linear classifier. It is parametrized
by a weight matrix :math:`W` and a bias vector :math:`b`. Classification is
done by projecting data points onto a set of hyperplanes, the distance to
which is used to determine a class membership probability.

Mathematically, this can be written as:

.. math::
  P(Y=i|x, W,b) &= softmax_i(W x + b) \\
                &= \frac {e^{W_i x + b_i}} {\sum_j e^{W_j x + b_j}}


The model's prediction is then made by taking the argmax of the vector
whose i'th element is P(Y=i|x).

.. math::

  y_{pred} = argmax_i P(Y=i|x,W,b)


This tutorial presents a stochastic gradient descent optimization method
suitable for large datasets, and a conjugate gradient optimization method
that is suitable for smaller datasets.


References:

    - textbooks: "Pattern Recognition and Machine Learning" -
                 Christopher M. Bishop, section 4.3.2

"""
__docformat__ = 'restructuredtext en'
import numpy, time, cPickle, gzip

import theano
import theano.tensor as T

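
# A quick numerical check of the softmax formula in the module docstring
# (an editor's sketch, never called; `_softmax_sketch` is a hypothetical
# helper, not part of the original tutorial):
def _softmax_sketch():
    x = numpy.ones(3)                     # one data point, n_in = 3
    W = numpy.zeros((3, 5))               # n_in x n_out weight matrix
    b = numpy.zeros(5)                    # n_out bias vector
    e = numpy.exp(numpy.dot(x, W) + b)    # e^{W_i x + b_i} for each class i
    p = e / e.sum()                       # P(Y=i|x, W, b); uniform 0.2 here
    assert abs(p.sum() - 1.0) < 1e-12     # probabilities sum to one
    return p.argmax()                     # y_pred = argmax_i P(Y=i|x, W, b)
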

class LogisticRegression(object):
    """Multi-class Logistic Regression Class

    The logistic regression is fully described by a weight matrix :math:`W`
    and bias vector :math:`b`. Classification is done by projecting data
    points onto a set of hyperplanes, the distance to which is used to
    determine a class membership probability.
    """

    def __init__(self, input, n_in, n_out):
        """ Initialize the parameters of the logistic regression

        :type input: theano.tensor.TensorType
        :param input: symbolic variable that describes the input of the
                      architecture (one minibatch)

        :type n_in: int
        :param n_in: number of input units, the dimension of the space in
                     which the datapoints lie

        :type n_out: int
        :param n_out: number of output units, the dimension of the space in
                      which the labels lie

        """

        # initialize with 0 the weights W as a matrix of shape (n_in, n_out)
        self.W = theano.shared(value=numpy.zeros((n_in, n_out),
                                                 dtype=theano.config.floatX),
                               name='W')
        # initialize the biases b as a vector of n_out 0s
        self.b = theano.shared(value=numpy.zeros((n_out,),
                                                 dtype=theano.config.floatX),
                               name='b')

        # compute vector of class-membership probabilities in symbolic form
        self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)

        # compute prediction as class whose probability is maximal in
        # symbolic form
        self.y_pred = T.argmax(self.p_y_given_x, axis=1)

        # parameters of the model
        self.params = [self.W, self.b]

    def negative_log_likelihood(self, y):
        """Return the mean of the negative log-likelihood of the prediction
        of this model under a given target distribution.

        .. math::

            \frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D}) =
            \frac{1}{|\mathcal{D}|} \sum_{i=0}^{|\mathcal{D}|}
                \log(P(Y=y^{(i)}|x^{(i)}, W, b)) \\
            \ell (\theta=\{W,b\}, \mathcal{D})

        :type y: theano.tensor.TensorType
        :param y: corresponds to a vector that gives for each example the
                  correct label

        Note: we use the mean instead of the sum so that
              the learning rate is less dependent on the batch size
        """
        # y.shape[0] is (symbolically) the number of rows in y, i.e., the
        # number of examples (call it n) in the minibatch.
        # T.arange(y.shape[0]) is a symbolic vector which will contain
        # [0, 1, 2, ..., n-1].
        # T.log(self.p_y_given_x) is a matrix of log-probabilities (call it
        # LP) with one row per example and one column per class.
        # LP[T.arange(y.shape[0]), y] is a vector v containing
        # [LP[0, y[0]], LP[1, y[1]], ..., LP[n-1, y[n-1]]], and
        # T.mean(LP[T.arange(y.shape[0]), y]) is the mean (across minibatch
        # examples) of the elements in v, i.e., the mean log-likelihood
        # across the minibatch.
        return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y])

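    # The fancy indexing in the return expression above is easiest to see in
    # plain numpy; a sketch with made-up numbers (an editor's addition, not
    # part of the original tutorial):
    #
    #   LP = numpy.log([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # 3 examples
    #   y = numpy.array([1, 0, 1])                            # their labels
    #   LP[numpy.arange(3), y]          # [log(0.9), log(0.8), log(0.7)]
    #   -LP[numpy.arange(3), y].mean()  # the minibatch NLL, about 0.228
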
    def errors(self, y):
        """Return a float representing the number of errors in the minibatch
        over the total number of examples of the minibatch; zero-one
        loss over the size of the minibatch

        :type y: theano.tensor.TensorType
        :param y: corresponds to a vector that gives for each example the
                  correct label
        """

        # check if y has same dimension of y_pred
        if y.ndim != self.y_pred.ndim:
            raise TypeError('y should have the same shape as self.y_pred',
                            ('y', y.type, 'y_pred', self.y_pred.type))
        # check if y is of the correct datatype
        if y.dtype.startswith('int'):
            # the T.neq operator returns a vector of 0s and 1s, where 1
            # represents a mistake in prediction
            return T.mean(T.neq(self.y_pred, y))
        else:
            raise NotImplementedError()

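
# A minimal end-to-end sketch of the class above (an editor's addition;
# `_classifier_sketch` is a hypothetical helper, never called, not part of
# the original tutorial):
def _classifier_sketch():
    x = T.matrix('x')     # a minibatch of rasterized images
    y = T.ivector('y')    # the corresponding integer labels
    clf = LogisticRegression(input=x, n_in=28 * 28, n_out=10)
    # compile a function mapping a concrete minibatch to its error rate
    f = theano.function([x, y], clf.errors(y))
    X = numpy.zeros((4, 28 * 28), dtype=theano.config.floatX)
    return f(X, numpy.zeros(4, dtype='int32'))
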
def load_data(dataset):
    ''' Loads the dataset

    :type dataset: string
    :param dataset: the path to the dataset (here MNIST)
    '''

    #############
    # LOAD DATA #
    #############
    print '... loading data'

    # Load the dataset
    f = gzip.open(dataset, 'rb')
    train_set, valid_set, test_set = cPickle.load(f)
    f.close()

    def shared_dataset(data_xy):
        """ Function that loads the dataset into shared variables

        The reason we store our dataset in shared variables is to allow
        Theano to copy it into the GPU memory (when code is run on GPU).
        Since copying data into the GPU is slow, copying a minibatch every
        time it is needed (the default behaviour if the data is not in a
        shared variable) would lead to a large decrease in performance.
        """
        data_x, data_y = data_xy
        shared_x = theano.shared(numpy.asarray(data_x,
                                               dtype=theano.config.floatX))
        shared_y = theano.shared(numpy.asarray(data_y,
                                               dtype=theano.config.floatX))
        # When storing data on the GPU it has to be stored as floats,
        # therefore we will store the labels as ``floatX`` as well
        # (``shared_y`` does exactly that). But during our computations
        # we need them as ints (we use labels as indices, and if they are
        # floats it doesn't make sense), therefore instead of returning
        # ``shared_y`` we will have to cast it to int. This little hack
        # lets us get around this issue.
        return shared_x, T.cast(shared_y, 'int32')

    test_set_x, test_set_y = shared_dataset(test_set)
    valid_set_x, valid_set_y = shared_dataset(valid_set)
    train_set_x, train_set_y = shared_dataset(train_set)

    rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y),
            (test_set_x, test_set_y)]
    return rval

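
# Usage sketch for load_data (an editor's addition; `_load_data_sketch` is
# a hypothetical helper, never called, not part of the original tutorial):
def _load_data_sketch():
    datasets = load_data('mnist.pkl.gz')
    train_set_x, train_set_y = datasets[0]
    # train_set_x is a shared variable, so its data lives wherever Theano
    # placed it (possibly GPU memory); train_set_y is an int32 expression
    # (a cast of a shared variable), not a shared variable itself
    return train_set_x.value.shape[0]    # number of training examples
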

def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
                           dataset='mnist.pkl.gz'):
    """
    Demonstrate stochastic gradient descent optimization of a log-linear
    model

    This is demonstrated on MNIST.

    :type learning_rate: float
    :param learning_rate: learning rate used (factor for the stochastic
                          gradient)

    :type n_epochs: int
    :param n_epochs: maximal number of epochs to run the optimizer

    :type dataset: string
    :param dataset: the path of the MNIST dataset file from
                    http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz

    """
    datasets = load_data(dataset)

    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]

    batch_size = 600    # size of the minibatch

    # compute number of minibatches for training, validation and testing
    n_train_batches = train_set_x.value.shape[0] / batch_size
    n_valid_batches = valid_set_x.value.shape[0] / batch_size
    n_test_batches = test_set_x.value.shape[0] / batch_size

    ######################
    # BUILD ACTUAL MODEL #
    ######################
    print '... building the model'

    # allocate symbolic variables for the data
    index = T.lscalar()    # index to a [mini]batch
    x = T.matrix('x')      # the data is presented as rasterized images
    y = T.ivector('y')     # the labels are presented as a 1D vector of
                           # [int] labels

    # construct the logistic regression class
    # Each MNIST image has size 28*28
    classifier = LogisticRegression(input=x, n_in=28 * 28, n_out=10)

    # the cost we minimize during training is the negative log likelihood of
    # the model in symbolic format
    cost = classifier.negative_log_likelihood(y)

    # compiling a Theano function that computes the mistakes that are made by
    # the model on a minibatch
    test_model = theano.function(inputs=[index],
            outputs=classifier.errors(y),
            givens={
                x: test_set_x[index * batch_size:(index + 1) * batch_size],
                y: test_set_y[index * batch_size:(index + 1) * batch_size]})

    validate_model = theano.function(inputs=[index],
            outputs=classifier.errors(y),
            givens={
                x: valid_set_x[index * batch_size:(index + 1) * batch_size],
                y: valid_set_y[index * batch_size:(index + 1) * batch_size]})

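    # How ``givens`` behaves, in a tiny standalone form (an editor's sketch
    # with made-up names, not part of the original tutorial):
    #
    #   s = theano.shared(numpy.arange(10, dtype=theano.config.floatX))
    #   i = T.lscalar()
    #   v = T.vector('v')
    #   f = theano.function([i], v.sum(), givens={v: s[i:i + 5]})
    #   f(0)    # sums s[0:5]; only the index crosses the Python boundary
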
    # compute the gradient of cost with respect to theta = (W,b)
    g_W = T.grad(cost=cost, wrt=classifier.W)
    g_b = T.grad(cost=cost, wrt=classifier.b)

    # specify how to update the parameters of the model as a dictionary
    updates = {classifier.W: classifier.W - learning_rate * g_W,
               classifier.b: classifier.b - learning_rate * g_b}

    # compiling a Theano function `train_model` that returns the cost and at
    # the same time updates the parameters of the model based on the rules
    # defined in `updates`
    train_model = theano.function(inputs=[index],
            outputs=cost,
            updates=updates,
            givens={
                x: train_set_x[index * batch_size:(index + 1) * batch_size],
                y: train_set_y[index * batch_size:(index + 1) * batch_size]})

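    # The ``updates`` dictionary above encodes one step of gradient descent,
    # W <- W - learning_rate * g_W and b <- b - learning_rate * g_b, which
    # Theano applies on every call to `train_model` (an editor's gloss, not
    # a comment from the original file).
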
    ###############
    # TRAIN MODEL #
    ###############
    print '... training the model'
    # early-stopping parameters
    patience = 5000    # look at this many minibatches regardless
    patience_increase = 2    # wait this much longer when a new best is
                             # found
    improvement_threshold = 0.995    # a relative improvement of this much is
                                     # considered significant
    validation_frequency = min(n_train_batches, patience / 2)
                                  # go through this many
                                  # minibatches before checking the network
                                  # on the validation set; in this case we
                                  # check every epoch

    best_params = None
    best_validation_loss = float('inf')
    test_score = 0.
    start_time = time.clock()

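    # Editor's note, a worked example of the schedule above: MNIST has 50000
    # training examples, so batch_size = 600 gives n_train_batches = 83 and
    # validation_frequency = min(83, 2500) = 83, i.e. we validate once per
    # epoch. If a clearly better model shows up at iteration 3000, patience
    # grows to max(5000, 3000 * 2) = 6000 minibatches.
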
    done_looping = False
    epoch = 0
    while (epoch < n_epochs) and (not done_looping):
        epoch = epoch + 1
        for minibatch_index in xrange(n_train_batches):

            minibatch_avg_cost = train_model(minibatch_index)
            # iteration number, counting the minibatches processed so far
            iter = (epoch - 1) * n_train_batches + minibatch_index

            if (iter + 1) % validation_frequency == 0:
                # compute zero-one loss on validation set
                validation_losses = [validate_model(i)
                                     for i in xrange(n_valid_batches)]
                this_validation_loss = numpy.mean(validation_losses)

                print('epoch %i, minibatch %i/%i, validation error %f %%' %
                      (epoch, minibatch_index + 1, n_train_batches,
                       this_validation_loss * 100.))

                # if we got the best validation score until now
                if this_validation_loss < best_validation_loss:
                    # improve patience if loss improvement is good enough
                    if this_validation_loss < best_validation_loss * \
                            improvement_threshold:
                        patience = max(patience, iter * patience_increase)

                    best_validation_loss = this_validation_loss
                    # test it on the test set
                    test_losses = [test_model(i)
                                   for i in xrange(n_test_batches)]
                    test_score = numpy.mean(test_losses)

                    print(('     epoch %i, minibatch %i/%i, test error of '
                           'best model %f %%') %
                          (epoch, minibatch_index + 1, n_train_batches,
                           test_score * 100.))

            if patience <= iter:
                done_looping = True
                break

    end_time = time.clock()
    print(('Optimization complete with best validation score of %f %%, '
           'with test performance %f %%') %
          (best_validation_loss * 100., test_score * 100.))
    print('The code ran for %f minutes' % ((end_time - start_time) / 60.))


if __name__ == '__main__':
    sgd_optimization_mnist()
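
# Editor's note: to run this script, download mnist.pkl.gz from the URL in
# the sgd_optimization_mnist docstring into the working directory and run
# `python logistic_sgd.py`.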