Mercurial > ift6266
annotate code_tutoriel/logistic_sgd.py @ 266:1e4e60ddadb1
Merge. Oh, and in the last commit I forgot to mention that I added code to handle isolating different clones, so experiments can run while the code is being modified at the same time.
author:   fsavard
date:     Fri, 19 Mar 2010 10:56:16 -0400
parents:  4bc5eeec6394
children: (none)
"""
This tutorial introduces logistic regression using Theano and stochastic
gradient descent.

Logistic regression is a probabilistic, linear classifier. It is parametrized
by a weight matrix :math:`W` and a bias vector :math:`b`. Classification is
done by projecting data points onto a set of hyperplanes, the distance to
which is used to determine a class membership probability.

Mathematically, this can be written as:

.. math::
  P(Y=i|x, W,b) &= softmax_i(W x + b) \\
                &= \frac {e^{W_i x + b_i}} {\sum_j e^{W_j x + b_j}}


The model's prediction is then the argmax of the vector whose i'th element
is P(Y=i|x).

.. math::

  y_{pred} = argmax_i P(Y=i|x,W,b)


This tutorial presents a stochastic gradient descent optimization method
suitable for large datasets, and a conjugate gradient optimization method
that is suitable for smaller datasets.


References:

    - textbooks: "Pattern Recognition and Machine Learning" -
                 Christopher M. Bishop, section 4.3.2

"""
__docformat__ = 'restructuredtext en'
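# A tiny numeric illustration of the two formulas above (a sketch with
# made-up values; not part of the original tutorial file):
#
#   >>> import numpy
#   >>> W = numpy.array([[1.0, -1.0], [0.5, 0.5]])   # (n_in=2, n_out=2)
#   >>> b = numpy.array([0.0, 0.1])
#   >>> x = numpy.array([2.0, 1.0])
#   >>> e = numpy.exp(numpy.dot(x, W) + b)
#   >>> p = e / e.sum()                # P(Y=i|x, W, b)
#   >>> int(p.argmax())                # y_pred
#   0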
import numpy, time, cPickle, gzip

import theano
import theano.tensor as T


class LogisticRegression(object):
    """Multi-class Logistic Regression Class

    The logistic regression is fully described by a weight matrix :math:`W`
    and bias vector :math:`b`. Classification is done by projecting data
    points onto a set of hyperplanes, the distance to which is used to
    determine a class membership probability.
    """

    def __init__(self, input, n_in, n_out):
        """ Initialize the parameters of the logistic regression

        :type input: theano.tensor.TensorType
        :param input: symbolic variable that describes the input of the
                      architecture (one minibatch)

        :type n_in: int
        :param n_in: number of input units, the dimension of the space in
                     which the datapoints lie

        :type n_out: int
        :param n_out: number of output units, the dimension of the space in
                      which the labels lie

        """

        # initialize with 0 the weights W as a matrix of shape (n_in, n_out)
        self.W = theano.shared(value=numpy.zeros((n_in, n_out),
                                                 dtype=theano.config.floatX),
                               name='W')
        # initialize the biases b as a vector of n_out 0s
        self.b = theano.shared(value=numpy.zeros((n_out,),
                                                 dtype=theano.config.floatX),
                               name='b')

        # compute vector of class-membership probabilities in symbolic form
        self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)

        # compute prediction as class whose probability is maximal in
        # symbolic form
        self.y_pred = T.argmax(self.p_y_given_x, axis=1)

        # parameters of the model
        self.params = [self.W, self.b]

    def negative_log_likelihood(self, y):
        """Return the mean of the negative log-likelihood of the prediction
        of this model under a given target distribution.

        .. math::

            \frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D}) =
                \frac{1}{|\mathcal{D}|} \sum_{i=0}^{|\mathcal{D}|-1}
                \log(P(Y=y^{(i)}|x^{(i)}, W, b)) \\
            \ell (\theta=\{W,b\}, \mathcal{D}) = - \mathcal{L} (\theta=\{W,b\}, \mathcal{D})

        :type y: theano.tensor.TensorType
        :param y: corresponds to a vector that gives for each example the
                  correct label

        Note: we use the mean instead of the sum so that
              the learning rate is less dependent on the batch size
        """
        # y.shape[0] is (symbolically) the number of rows in y, i.e.,
        # the number of examples (call it n) in the minibatch
        # T.arange(y.shape[0]) is a symbolic vector which will contain
        # [0,1,2,... n-1]
        # T.log(self.p_y_given_x) is a matrix of Log-Probabilities (call it
        # LP) with one row per example and one column per class
        # LP[T.arange(y.shape[0]),y] is a vector v containing
        # [LP[0,y[0]], LP[1,y[1]], LP[2,y[2]], ..., LP[n-1,y[n-1]]]
        # and T.mean(LP[T.arange(y.shape[0]),y]) is the mean (across minibatch
        # examples) of the elements in v, i.e., the mean log-likelihood across
        # the minibatch.
        return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y])

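    # A quick numpy illustration of the indexing trick used above (made-up
    # probabilities; a sketch, not part of the original file):
    #
    #   >>> import numpy
    #   >>> LP = numpy.log(numpy.array([[0.1, 0.9], [0.8, 0.2]]))
    #   >>> y = numpy.array([1, 0])
    #   >>> v = LP[numpy.arange(2), y]        # LP[0, y[0]] and LP[1, y[1]]
    #   >>> round(float(-v.mean()), 3)        # mean negative log-likelihood
    #   0.164
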
    def errors(self, y):
        """Return a float representing the number of errors in the minibatch
        over the total number of examples in the minibatch; i.e., the
        zero-one loss over the size of the minibatch

        :type y: theano.tensor.TensorType
        :param y: corresponds to a vector that gives for each example the
                  correct label
        """

        # check if y has the same dimension as y_pred
        if y.ndim != self.y_pred.ndim:
            raise TypeError('y should have the same shape as self.y_pred',
                            ('y', y.type, 'y_pred', self.y_pred.type))
        # check if y is of the correct datatype
        if y.dtype.startswith('int'):
            # the T.neq operator returns a vector of 0s and 1s, where 1
            # represents a mistake in prediction
            return T.mean(T.neq(self.y_pred, y))
        else:
            raise NotImplementedError()
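    # What the zero-one loss computes, in plain numpy (a sketch with made-up
    # predictions, not part of the original file):
    #
    #   >>> import numpy
    #   >>> y_pred = numpy.array([0, 1, 2, 1])
    #   >>> y = numpy.array([0, 2, 2, 1])
    #   >>> float((y_pred != y).mean())       # fraction of minibatch errors
    #   0.25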


def load_data(dataset):
    ''' Loads the dataset

    :type dataset: string
    :param dataset: the path to the dataset (here MNIST)
    '''

    #############
    # LOAD DATA #
    #############
    print '... loading data'

    # Load the dataset
    f = gzip.open(dataset, 'rb')
    train_set, valid_set, test_set = cPickle.load(f)
    f.close()

    def shared_dataset(data_xy):
        """ Function that loads the dataset into shared variables

        The reason we store our dataset in shared variables is to allow
        Theano to copy it into the GPU memory (when code is run on GPU).
        Since copying data into the GPU is slow, copying a minibatch every
        time it is needed (the default behaviour if the data is not in a
        shared variable) would lead to a large decrease in performance.
        """
        data_x, data_y = data_xy
        shared_x = theano.shared(numpy.asarray(data_x,
                                               dtype=theano.config.floatX))
        shared_y = theano.shared(numpy.asarray(data_y,
                                               dtype=theano.config.floatX))
        # When storing data on the GPU it has to be stored as floats,
        # therefore we will store the labels as ``floatX`` as well
        # (``shared_y`` does exactly that). But during our computations
        # we need them as ints (we use labels as indices, and if they are
        # floats it doesn't make sense), therefore instead of returning
        # ``shared_y`` we will have to cast it to int. This little hack
        # lets us get around this issue
        return shared_x, T.cast(shared_y, 'int32')

    test_set_x, test_set_y = shared_dataset(test_set)
    valid_set_x, valid_set_y = shared_dataset(valid_set)
    train_set_x, train_set_y = shared_dataset(train_set)

    rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y),
            (test_set_x, test_set_y)]
    return rval


def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
                           dataset='mnist.pkl.gz'):
    """
    Demonstrate stochastic gradient descent optimization of a log-linear
    model

    This is demonstrated on MNIST.

    :type learning_rate: float
    :param learning_rate: learning rate used (factor for the stochastic
                          gradient)

    :type n_epochs: int
    :param n_epochs: maximal number of epochs to run the optimizer

    :type dataset: string
    :param dataset: the path of the MNIST dataset file from
                    http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz

    """
    datasets = load_data(dataset)

    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]

    batch_size = 600    # size of the minibatch

    # compute number of minibatches for training, validation and testing
    n_train_batches = train_set_x.value.shape[0] / batch_size
    n_valid_batches = valid_set_x.value.shape[0] / batch_size
    n_test_batches = test_set_x.value.shape[0] / batch_size

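    # Note that ``/`` above is Python 2 integer division, e.g. (a sketch,
    # not in the original file; mnist.pkl.gz has 50000 training examples):
    #
    #   >>> 50000 / 600
    #   83
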
    ######################
    # BUILD ACTUAL MODEL #
    ######################
    print '... building the model'

    # allocate symbolic variables for the data
    index = T.lscalar()    # index to a [mini]batch
    x = T.matrix('x')      # the data is presented as rasterized images
    y = T.ivector('y')     # the labels are presented as 1D vector of
                           # [int] labels

    # construct the logistic regression class
    # Each MNIST image has size 28*28
    classifier = LogisticRegression(input=x, n_in=28 * 28, n_out=10)

    # the cost we minimize during training is the negative log likelihood of
    # the model in symbolic format
    cost = classifier.negative_log_likelihood(y)

    # compiling a Theano function that computes the mistakes that are made by
    # the model on a minibatch
    test_model = theano.function(inputs=[index],
            outputs=classifier.errors(y),
            givens={
                x: test_set_x[index * batch_size:(index + 1) * batch_size],
                y: test_set_y[index * batch_size:(index + 1) * batch_size]})

    validate_model = theano.function(inputs=[index],
            outputs=classifier.errors(y),
            givens={
                x: valid_set_x[index * batch_size:(index + 1) * batch_size],
                y: valid_set_y[index * batch_size:(index + 1) * batch_size]})
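
    # Each call test_model(i) evaluates the zero-one loss on rows
    # [i*batch_size, (i+1)*batch_size) of the test set, e.g. (a sketch,
    # not in the original file):
    #
    #   >>> batch_size = 600
    #   >>> index = 2
    #   >>> (index * batch_size, (index + 1) * batch_size)
    #   (1200, 1800)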

    # compute the gradient of cost with respect to theta = (W, b)
    g_W = T.grad(cost=cost, wrt=classifier.W)
    g_b = T.grad(cost=cost, wrt=classifier.b)

    # specify how to update the parameters of the model as a dictionary
    updates = {classifier.W: classifier.W - learning_rate * g_W,
               classifier.b: classifier.b - learning_rate * g_b}

    # compiling a Theano function `train_model` that returns the cost and at
    # the same time updates the parameters of the model based on the rules
    # defined in `updates`
    train_model = theano.function(inputs=[index],
            outputs=cost,
            updates=updates,
            givens={
                x: train_set_x[index * batch_size:(index + 1) * batch_size],
                y: train_set_y[index * batch_size:(index + 1) * batch_size]})

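    # One SGD step per parameter, numerically (a scalar sketch with made-up
    # values, not part of the original file):
    #
    #   >>> W, g_W, learning_rate = 0.5, 2.0, 0.13
    #   >>> round(W - learning_rate * g_W, 2)
    #   0.24
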
    ###############
    # TRAIN MODEL #
    ###############
    print '... training the model'
    # early-stopping parameters
    patience = 5000    # look at this many examples regardless
    patience_increase = 2    # wait this much longer when a new best is
                             # found
    improvement_threshold = 0.995    # a relative improvement of this much is
                                     # considered significant
    validation_frequency = min(n_train_batches, patience / 2)
                             # go through this many
                             # minibatches before checking the network
                             # on the validation set; in this case we
                             # check every epoch

    best_params = None
    best_validation_loss = float('inf')
    test_score = 0.
    start_time = time.clock()

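    # The early-stopping test, numerically (a sketch, not in the original
    # file): a new validation loss only counts as a significant improvement
    # when it beats the best one by the relative threshold.
    #
    #   >>> best_validation_loss, improvement_threshold = 0.080, 0.995
    #   >>> this_validation_loss = 0.079
    #   >>> this_validation_loss < best_validation_loss * improvement_threshold
    #   True
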
    done_looping = False
    epoch = 0
    while (epoch < n_epochs) and (not done_looping):
        epoch = epoch + 1
        for minibatch_index in xrange(n_train_batches):

            minibatch_avg_cost = train_model(minibatch_index)
            # iteration number
            iter = epoch * n_train_batches + minibatch_index

            if (iter + 1) % validation_frequency == 0:
                # compute zero-one loss on validation set
                validation_losses = [validate_model(i)
                                     for i in xrange(n_valid_batches)]
                this_validation_loss = numpy.mean(validation_losses)

                print('epoch %i, minibatch %i/%i, validation error %f %%' %
                      (epoch, minibatch_index + 1, n_train_batches,
                       this_validation_loss * 100.))

                # if we got the best validation score until now
                if this_validation_loss < best_validation_loss:
                    # improve patience if loss improvement is good enough
                    if this_validation_loss < best_validation_loss * \
                            improvement_threshold:
                        patience = max(patience, iter * patience_increase)

                    best_validation_loss = this_validation_loss
                    # test it on the test set
                    test_losses = [test_model(i)
                                   for i in xrange(n_test_batches)]
                    test_score = numpy.mean(test_losses)

                    print(('     epoch %i, minibatch %i/%i, test error of '
                           'best model %f %%') %
                          (epoch, minibatch_index + 1, n_train_batches,
                           test_score * 100.))

            if patience <= iter:
                done_looping = True
                break

    end_time = time.clock()
    print(('Optimization complete with best validation score of %f %%, '
           'with test performance %f %%') %
          (best_validation_loss * 100., test_score * 100.))
    print('The code ran for %f minutes' % ((end_time - start_time) / 60.))


if __name__ == '__main__':
    sgd_optimization_mnist()
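
# Usage (assuming ``mnist.pkl.gz`` sits in the working directory, per the
# default ``dataset`` argument above):
#
#   python logistic_sgd.py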