annotate baseline/log_reg/log_reg.py @ 222:4cfd0eb438af

Add mnist to datasets (and supporting code).
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 11 Mar 2010 14:41:31 -0500
parents 777f48ba30df
children 7be1f086a89e
rev 158:d1bb6e06497a new logistic regression directory
author Myriam Cote <cotemyri@iro.umontreal.ca>

r"""
This tutorial introduces logistic regression using Theano and stochastic
gradient descent.

Logistic regression is a probabilistic, linear classifier. It is parametrized
by a weight matrix :math:`W` and a bias vector :math:`b`. Classification is
done by projecting data points onto a set of hyperplanes, the distance to
which is used to determine a class membership probability.

Mathematically, this can be written as:

.. math::

  P(Y=i|x, W,b) &= \mathrm{softmax}_i(W x + b) \\
                &= \frac {e^{W_i x + b_i}} {\sum_j e^{W_j x + b_j}}


The model's prediction is then obtained by taking the argmax of the vector
whose i'th element is :math:`P(Y=i|x,W,b)`.

.. math::

  y_{pred} = \mathrm{argmax}_i P(Y=i|x,W,b)


This tutorial presents a stochastic gradient descent optimization method
suitable for large datasets, and a conjugate gradient optimization method
that is suitable for smaller datasets.


References:

    - textbooks: "Pattern Recognition and Machine Learning" -
                 Christopher M. Bishop, section 4.3.2

"""
__docformat__ = 'restructuredtext en'

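# Illustrative NumPy sketch, not part of the original tutorial. It evaluates the
# two formulas in the docstring above for a minibatch stored one example per
# row. That is the layout the Theano code below uses: T.dot(input, W) + b with
# W of shape (n_in, n_out), so column i of W plays the role of W_i.
# Subtracting each row's maximum before exponentiating is a standard overflow
# guard and leaves the softmax unchanged. The names softmax_rows and
# predict_rows are hypothetical, chosen for this illustration.
import numpy


def softmax_rows(scores):
    """Row-wise softmax of a (n_examples, n_classes) array of scores."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = numpy.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


def predict_rows(X, W, b):
    """Return P(Y=i|x) for every row x of X, and the argmax class per row."""
    probs = softmax_rows(numpy.dot(X, W) + b)
    return probs, numpy.argmax(probs, axis=1)

# Usage: predict_rows(numpy.eye(2), numpy.array([[2., 0.], [0., 3.]]), numpy.zeros(2))
# gives probability rows that sum to 1 and predictions [0, 1].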
import numpy, time

import theano
import theano.tensor as T
from ift6266 import datasets

class LogisticRegression(object):
    """Multi-class Logistic Regression Class

    The logistic regression is fully described by a weight matrix :math:`W`
    and bias vector :math:`b`. Classification is done by projecting data
    points onto a set of hyperplanes, the distance to which is used to
    determine a class membership probability.
    """

    def __init__(self, input, n_in, n_out):
        """ Initialize the parameters of the logistic regression

        :type input: theano.tensor.TensorType
        :param input: symbolic variable that describes the input of the
                      architecture (one minibatch)

        :type n_in: int
        :param n_in: number of input units, the dimension of the space in
                     which the datapoints lie

        :type n_out: int
        :param n_out: number of output units, the dimension of the space in
                      which the labels lie

        """

        # initialize the weights W with 0 as a matrix of shape (n_in, n_out)
        self.W = theano.shared(value=numpy.zeros((n_in, n_out), dtype=theano.config.floatX),
                               name='W')
        # initialize the biases b as a vector of n_out 0s
        self.b = theano.shared(value=numpy.zeros((n_out,), dtype=theano.config.floatX),
                               name='b')

        # compute the class-membership probabilities in symbolic form
        # (a matrix with one row per example and one column per class)
        self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)

        # compute prediction as class whose probability is maximal in
        # symbolic form
        self.y_pred = T.argmax(self.p_y_given_x, axis=1)

        # parameters of the model
        self.params = [self.W, self.b]

    def negative_log_likelihood(self, y):
        r"""Return the mean of the negative log-likelihood of the prediction
        of this model with respect to the given target labels.

        .. math::

            \mathcal{L}(\theta=\{W,b\}, \mathcal{D}) &=
                \sum_{i=0}^{|\mathcal{D}|-1} \log P(Y=y^{(i)}|x^{(i)}, W,b) \\
            \ell(\theta=\{W,b\}, \mathcal{D}) &=
                -\frac{1}{|\mathcal{D}|} \mathcal{L}(\theta=\{W,b\}, \mathcal{D})

        :type y: theano.tensor.TensorType
        :param y: corresponds to a vector that gives for each example the
                  correct label

        Note: we use the mean instead of the sum so that
              the learning rate is less dependent on the batch size
        """
        # y.shape[0] is (symbolically) the number of rows in y, i.e., the
        # number of examples (call it n) in the minibatch.
        # T.arange(y.shape[0]) is a symbolic vector containing [0, 1, 2, ..., n-1].
        # T.log(self.p_y_given_x) is a matrix of log-probabilities (call it LP)
        # with one row per example and one column per class.
        # LP[T.arange(y.shape[0]), y] is a vector v containing
        # [LP[0,y[0]], LP[1,y[1]], LP[2,y[2]], ..., LP[n-1,y[n-1]]],
        # and T.mean(LP[T.arange(y.shape[0]), y]) is the mean (across minibatch
        # examples) of the elements in v, i.e., the mean log-likelihood across
        # the minibatch. Negating it gives the mean negative log-likelihood.
        return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y])

    def MSE(self, y):
        """Return the mean squared error between the predicted class
        probabilities and the one-hot encoding of the labels ``y``: squared
        differences are summed over classes and averaged over the minibatch.
        """
        # one-hot matrix: row i has a 1 in column y[i] and 0 elsewhere
        y_onehot = T.eq(T.arange(self.p_y_given_x.shape[1]).dimshuffle('x', 0),
                        y.dimshuffle(0, 'x'))
        return T.mean(T.sum((self.p_y_given_x - y_onehot) ** 2, axis=1))

    def errors(self, y):
        """Return a float representing the fraction of misclassified examples
        in the minibatch, i.e. the zero-one loss averaged over the minibatch

        :type y: theano.tensor.TensorType
        :param y: corresponds to a vector that gives for each example the
                  correct label
        """

        # check if y has the same number of dimensions as y_pred
        if y.ndim != self.y_pred.ndim:
            raise TypeError('y should have the same number of dimensions as self.y_pred',
                            ('y', y.type, 'y_pred', self.y_pred.type))
        # check if y is of the correct datatype
        if y.dtype.startswith('int'):
            # the T.neq operator returns a vector of 0s and 1s, where 1
            # represents a mistake in prediction
            return T.mean(T.neq(self.y_pred, y))
        else:
            raise NotImplementedError()
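# Illustrative NumPy sketch, not part of the original tutorial, of what
# LogisticRegression.negative_log_likelihood and LogisticRegression.errors
# compute symbolically for one minibatch. The fancy index
# log_probs[numpy.arange(n), y] picks, for each row i, the log-probability of
# that row's correct label y[i], as the comments in negative_log_likelihood
# describe. The function names are hypothetical.
import numpy


def mean_negative_log_likelihood(p_y_given_x, y):
    """Mean over the minibatch of -log P(Y=y[i] | x[i])."""
    log_probs = numpy.log(p_y_given_x)                # LP: one row per example
    picked = log_probs[numpy.arange(y.shape[0]), y]   # [LP[0,y[0]], ..., LP[n-1,y[n-1]]]
    return -numpy.mean(picked)


def zero_one_error_rate(p_y_given_x, y):
    """Fraction of examples whose argmax prediction differs from the label."""
    y_pred = numpy.argmax(p_y_given_x, axis=1)
    return numpy.mean(y_pred != y)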
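# Illustrative NumPy sketch, not part of the original tutorial, of a mean
# squared error cost for this classifier. It assumes the MSE is meant as the
# squared distance between p(y|x) and the one-hot encoding of the integer
# label, summed over classes and averaged over the minibatch. The function
# name is hypothetical.
import numpy


def mean_squared_error_one_hot(p_y_given_x, y):
    """MSE between predicted class probabilities and one-hot targets."""
    n, n_classes = p_y_given_x.shape
    one_hot = numpy.zeros((n, n_classes))
    one_hot[numpy.arange(n), y] = 1.0   # row i has a 1 in column y[i]
    return numpy.mean(numpy.sum((p_y_given_x - one_hot) ** 2, axis=1))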
#--------------------------------------------------------------------------------------------------------------------
# MAIN
#--------------------------------------------------------------------------------------------------------------------
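# Illustrative NumPy sketch, not part of the original tutorial, of one
# stochastic gradient descent step on the mean negative log-likelihood. This
# is what T.grad computes symbolically in log_reg below. The step then applies
# the plain update theta <- theta - learning_rate * gradient; the exact update
# dictionary is outside this excerpt, so this form is an assumption. For
# softmax regression the gradients have the closed form
#     g_W = X^T (P - Y) / n,    g_b = column sums of (P - Y) / n,
# where P = softmax(X W + b) row-wise and Y is the one-hot label matrix.
# The function name sgd_step is hypothetical.
import numpy


def sgd_step(W, b, X, y, learning_rate):
    """Return (W, b) after one SGD step on the minibatch (X, y)."""
    n = X.shape[0]
    scores = numpy.dot(X, W) + b
    scores -= scores.max(axis=1, keepdims=True)   # overflow guard
    P = numpy.exp(scores)
    P /= P.sum(axis=1, keepdims=True)             # row-wise softmax
    Y = numpy.zeros_like(P)
    Y[numpy.arange(n), y] = 1.0                   # one-hot labels
    delta = (P - Y) / n                           # d(mean NLL) / d(scores)
    g_W = numpy.dot(X.T, delta)
    g_b = delta.sum(axis=0)
    return W - learning_rate * g_W, b - learning_rate * g_b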
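# Illustrative sketch, not part of the original tutorial, of a patience-based
# early-stopping rule of the kind described by the patience,
# patience_increase and improvement_threshold parameters of log_reg below.
# The training loop that uses them lies outside this excerpt, so the way they
# are combined here is an assumption. Training runs for at least `patience`
# examples. Each validation loss that beats the best so far by the relative
# improvement_threshold raises patience to at least patience_increase times
# the number of examples seen. The function name patience_stop_point is
# hypothetical.


def patience_stop_point(validation_losses, validation_every, patience=5000,
                        patience_increase=2, improvement_threshold=0.995):
    """Return (examples_seen_at_stop, best_loss) for validation losses
    measured every `validation_every` examples."""
    best_loss = float('inf')
    seen = 0
    for loss in validation_losses:
        seen += validation_every
        if loss < best_loss:
            # only a significant relative improvement extends the patience
            if loss < best_loss * improvement_threshold:
                patience = max(patience, seen * patience_increase)
            best_loss = loss
        if patience <= seen:
            break
    return seen, best_loss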
def log_reg(learning_rate=0.13, nb_max_examples=1000000, batch_size=50,
            dataset=datasets.nist_digits, image_size=32 * 32, nb_class=10,
            patience=5000, patience_increase=2, improvement_threshold=0.995):
    """
    Demonstrate stochastic gradient descent optimization of a log-linear
    model

    This is demonstrated by default on the NIST digits dataset (see the
    ``dataset`` parameter).

    :type learning_rate: float
    :param learning_rate: learning rate used (factor for the stochastic
                          gradient)

    :type nb_max_examples: int
    :param nb_max_examples: maximal number of examples to run the optimizer on

    :type batch_size: int
    :param batch_size: size of the minibatch

    :type dataset: dataset
    :param dataset: a dataset instance from ift6266.datasets

    :type image_size: int
    :param image_size: size of the input image in pixels (width * height)

    :type nb_class: int
    :param nb_class: number of classes

    :type patience: int
    :param patience: look at this many examples regardless

    :type patience_increase: int
    :param patience_increase: wait this much longer when a new best is found

    :type improvement_threshold: float
    :param improvement_threshold: a relative improvement of this much is considered significant

    """
d1bb6e06497a nouveau répertoire régression logistique
Myriam Cote <cotemyri@iro.umontreal.ca>
parents:
diff changeset
184 #--------------------------------------------------------------------------------------------------------------------
d1bb6e06497a nouveau répertoire régression logistique
Myriam Cote <cotemyri@iro.umontreal.ca>
parents:
diff changeset
185 # Build actual model
d1bb6e06497a nouveau répertoire régression logistique
Myriam Cote <cotemyri@iro.umontreal.ca>
parents:
diff changeset
186 #--------------------------------------------------------------------------------------------------------------------
d1bb6e06497a nouveau répertoire régression logistique
Myriam Cote <cotemyri@iro.umontreal.ca>
parents:
diff changeset
187
d1bb6e06497a nouveau répertoire régression logistique
Myriam Cote <cotemyri@iro.umontreal.ca>
parents:
diff changeset
188 print '... building the model'
d1bb6e06497a nouveau répertoire régression logistique
Myriam Cote <cotemyri@iro.umontreal.ca>
parents:
diff changeset
189
d1bb6e06497a nouveau répertoire régression logistique
Myriam Cote <cotemyri@iro.umontreal.ca>
parents:
diff changeset
190 # allocate symbolic variables for the data
d1bb6e06497a nouveau répertoire régression logistique
Myriam Cote <cotemyri@iro.umontreal.ca>
parents:
diff changeset
191 index = T.lscalar( ) # index to a [mini]batch
d1bb6e06497a nouveau répertoire régression logistique
Myriam Cote <cotemyri@iro.umontreal.ca>
parents:
diff changeset
192 x = T.matrix('x') # the data is presented as rasterized images
d1bb6e06497a nouveau répertoire régression logistique
Myriam Cote <cotemyri@iro.umontreal.ca>
parents:
diff changeset
    y = T.ivector('y')  # the labels are presented as a 1D vector of
                        # [int] labels

    # construct the logistic regression class
    classifier = LogisticRegression(input=x, n_in=image_size, n_out=nb_class)

    # the cost we minimize during training is the negative log likelihood of
    # the model in symbolic format
    cost = classifier.negative_log_likelihood(y)

    # compiling a Theano function that computes the mistakes that are made by
    # the model on a minibatch
198
5d88ed99c0af Modify the log_reg.py tutorial code to use the datasets module.
Arnaud Bergeron <abergeron@gmail.com>
parents: 169
diff changeset
    test_model = theano.function(inputs=[x, y],
                                 outputs=classifier.errors(y))

    # the same error measure, evaluated on validation minibatches
    validate_model = theano.function(inputs=[x, y],
                                     outputs=classifier.errors(y))

    # compute the gradient of cost with respect to theta = (W, b)
    g_W = T.grad(cost=cost, wrt=classifier.W)
    g_b = T.grad(cost=cost, wrt=classifier.b)

    # specify how to update the parameters of the model as a dictionary
    updates = {classifier.W: classifier.W - learning_rate * g_W,
               classifier.b: classifier.b - learning_rate * g_b}

    # compiling a Theano function `train_model` that returns the cost, but at
    # the same time updates the parameters of the model based on the rules
    # defined in `updates`
    train_model = theano.function(inputs=[x, y],
                                  outputs=cost,
                                  updates=updates)

    #--------------------------------------------------------------------------------------------------------------------
    # Train model
    #--------------------------------------------------------------------------------------------------------------------

    print('... training the model')
    # early-stopping parameters
    patience = 5000                # look at this many minibatches regardless
    patience_increase = 2          # wait this much longer when a new best is
                                   # found
    improvement_threshold = 0.995  # a relative improvement of this much is
                                   # considered significant
    validation_frequency = patience // 2
                                   # go through this many
                                   # minibatches before checking the network
                                   # on the validation set; with patience = 5000
                                   # that is every 2500 minibatches, which is
                                   # not necessarily once per epoch

    best_params = None
    best_validation_loss = float('inf')
    test_score = 0.
    start_time = time.clock()

    done_looping = False
    n_iters = nb_max_examples // batch_size  # total number of minibatch updates allowed
    epoch = 0
    iter = 0

    while (iter < n_iters) and (not done_looping):

        epoch = epoch + 1
        # x_batch / y_batch keep the symbolic x and y above from being rebound
        for x_batch, y_batch in dataset.train(batch_size):

            minibatch_avg_cost = train_model(x_batch, y_batch)
            # iteration number
            iter += 1

            if iter % validation_frequency == 0:
                # compute zero-one loss on validation set
                validation_losses = [validate_model(xv, yv)
                                     for xv, yv in dataset.valid(batch_size)]
                this_validation_loss = numpy.mean(validation_losses)

                print('epoch %i, iter %i, validation error %f %%' %
                      (epoch, iter, this_validation_loss * 100.))

                # if we got the best validation score until now
                if this_validation_loss < best_validation_loss:
                    # improve patience if loss improvement is good enough
                    if this_validation_loss < best_validation_loss * \
                            improvement_threshold:
                        patience = max(patience, iter * patience_increase)

                    best_validation_loss = this_validation_loss

                    # test it on the test set
                    test_losses = [test_model(xt, yt)
                                   for xt, yt in dataset.test(batch_size)]
                    test_score = numpy.mean(test_losses)

                    print((' epoch %i, iter %i, test error of best '
                           'model %f %%') %
                          (epoch, iter, test_score * 100.))

            if patience <= iter:
                done_looping = True
                break

    end_time = time.clock()
    print(('Optimization complete with best validation score of %f %%, '
           'with test performance %f %%') %
          (best_validation_loss * 100., test_score * 100.))
    print('The code ran for %f minutes' % ((end_time - start_time) / 60.))

    # return the four values jobman_log_reg() unpacks; reading nb_exemples as
    # "examples processed" and time as "elapsed seconds" is an assumption
    return best_validation_loss, test_score, iter * batch_size, end_time - start_time

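The symbolic `cost` and `classifier.errors(y)` compiled in `log_reg` are easier to inspect outside Theano. The NumPy sketch below rests on three assumptions, because the `LogisticRegression` internals do not appear in this excerpt:

* `LogisticRegression` models p(y | x) = softmax(xW + b).
* `negative_log_likelihood` averages over the minibatch.
* `errors` returns the fraction of misclassified examples.

The function names are hypothetical stand-ins and are not the class's API.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def negative_log_likelihood(X, y, W, b):
    """Mean negative log-probability of the correct labels over a minibatch.

    Assumes p(y | x) = softmax(x W + b); y holds integer class labels.
    """
    p_y_given_x = softmax(X @ W + b)
    return -np.mean(np.log(p_y_given_x[np.arange(y.shape[0]), y]))


def zero_one_error(X, y, W, b):
    """Fraction of examples whose most probable class differs from the label."""
    y_pred = np.argmax(X @ W + b, axis=1)
    return np.mean(y_pred != y)
```

With all-zero parameters every class gets probability 1/n_out, so the cost starts at log(n_out). That makes a quick sanity check on a freshly initialised model.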
if __name__ == '__main__':
    log_reg()


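`T.grad` in `log_reg` derives the gradients symbolically. Under the same softmax assumption, the gradient of the mean negative log-likelihood over a minibatch of n examples has a closed form. Let P be the matrix of predicted probabilities and Y the one-hot labels. Then d cost / dW = X^T (P - Y) / n, and d cost / db is the column mean of P - Y.

The sketch below writes this out in NumPy, together with one step of the rule stored in `updates` (parameter minus `learning_rate` times gradient). `nll`, `nll_gradients` and `sgd_step` are hypothetical names chosen for illustration.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax with the usual max-subtraction for stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def nll(X, y, W, b):
    """Mean negative log-likelihood of labels y under softmax(X W + b)."""
    p = softmax(X @ W + b)
    return -np.mean(np.log(p[np.arange(len(y)), y]))


def nll_gradients(X, y, W, b):
    """Closed-form gradients of nll() with respect to W and b."""
    n = X.shape[0]
    p = softmax(X @ W + b)        # fresh array, safe to modify in place
    p[np.arange(n), y] -= 1.0     # p now holds P - Y (Y = one-hot labels)
    g_W = X.T @ p / n
    g_b = p.mean(axis=0)
    return g_W, g_b


def sgd_step(X, y, W, b, learning_rate):
    """One plain SGD step: new parameter = old parameter - learning_rate * gradient."""
    g_W, g_b = nll_gradients(X, y, W, b)
    return W - learning_rate * g_W, b - learning_rate * g_b
```

Comparing `nll_gradients` against central finite differences of `nll` is a cheap way to confirm the closed form before relying on it.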
def jobman_log_reg(state, channel):
    (validation_error, test_error, nb_exemples, time) = log_reg(
        learning_rate=state.learning_rate,
        nb_max_examples=state.nb_max_examples,
        batch_size=state.batch_size,
        dataset_name=state.dataset_name,
        image_size=state.image_size,
        nb_class=state.nb_class)

    state.validation_error = validation_error
    state.test_error = test_error
    state.nb_exemples = nb_exemples
    state.time = time
    return channel.COMPLETE
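In the training loop, `this_validation_loss` is `numpy.mean` of per-minibatch error rates. That equals the error rate over the whole validation set only when every minibatch has the same size. Whether `dataset.valid(batch_size)` can yield a shorter final minibatch is not shown in this excerpt.

A pure-Python sketch of the difference follows. The helper names are hypothetical, and the input is one 0/1 error flag per example.

```python
def mean_of_batch_errors(errors, batch_size):
    """Mean of per-minibatch error rates, mirroring numpy.mean(validation_losses).

    `errors` holds one 0/1 entry per example (1 = misclassified).
    """
    batches = [errors[i:i + batch_size] for i in range(0, len(errors), batch_size)]
    rates = [sum(batch) / float(len(batch)) for batch in batches]
    return sum(rates) / float(len(rates))


def overall_error(errors):
    """Error rate computed over all examples at once."""
    return sum(errors) / float(len(errors))


# 25 examples with a single mistake that falls in the short last batch
errors = [0] * 24 + [1]
print(mean_of_batch_errors(errors, 10))  # batches of 10, 10 and 5 examples
print(overall_error(errors))
```

In this example the batch-averaged rate is about 0.067 while the true error rate is 0.04. The short batch is weighted as heavily as the full ones.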
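The patience rule in the `log_reg` training loop does not depend on Theano. Training stops once `iter` reaches `patience`. Each validation loss below `best_validation_loss * improvement_threshold` extends `patience` to `iter * patience_increase`.

The pure-Python sketch below replays only that bookkeeping, under these assumptions:

* `validation_loss_at` is a hypothetical callable standing in for the `validate_model` loop.
* The epoch loop is flattened into one counter, named `it` to avoid shadowing the builtin `iter`.
* `best_iter` is an extra value returned for inspection.
* The defaults copy the values set in `log_reg`.

```python
def early_stopping_run(validation_loss_at, n_iters, patience=5000,
                       patience_increase=2, improvement_threshold=0.995,
                       validation_frequency=2500):
    """Replay the patience logic of the training loop without any model.

    validation_loss_at(it) is a hypothetical stand-in returning the
    validation loss measured after `it` minibatch updates.
    Returns (best_validation_loss, best_iter, last_iter).
    """
    best_validation_loss = float('inf')
    best_iter = 0
    it = 0
    while it < n_iters:
        it += 1                                  # one minibatch update
        if it % validation_frequency == 0:
            this_loss = validation_loss_at(it)
            if this_loss < best_validation_loss:
                # only a large enough relative improvement buys more patience
                if this_loss < best_validation_loss * improvement_threshold:
                    patience = max(patience, it * patience_increase)
                best_validation_loss = this_loss
                best_iter = it
        if patience <= it:                       # patience exhausted: stop early
            break
    return best_validation_loss, best_iter, it
```

As an example, take a validation curve that falls to 0.1 by iteration 10000 and then stays flat. The last improvement raises `patience` to 20000, so the run stops there even though `n_iters` allows 100000.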