annotate baseline/mlp/ratio_classes/mlp_nist_ratio.py @ 409:f0c2e3cfb1f1

added some images to illustrate the transformation
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Wed, 28 Apr 2010 16:39:10 -0400
parents 9a7b74927f7d
children d8129a09ffb1
rev   line source
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
1 # -*- coding: utf-8 -*-
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
2 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
3 This tutorial introduces the multilayer perceptron using Theano.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
4
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
5 A multilayer perceptron is a logistic regressor where
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
6 instead of feeding the input to the logistic regression you insert a
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
7 intermidiate layer, called the hidden layer, that has a nonlinear
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
8 activation function (usually tanh or sigmoid) . One can use many such
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
9 hidden layers making the architecture deep. The tutorial will also tackle
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
10 the problem of MNIST digit classification.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
11
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
12 .. math::
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
13
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
14 f(x) = G( b^{(2)} + W^{(2)}( s( b^{(1)} + W^{(1)} x))),
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
15
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
16 References:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
17
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
18 - textbooks: "Pattern Recognition and Machine Learning" -
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
19 Christopher M. Bishop, section 5
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
20
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
21 TODO: recommended preprocessing, lr ranges, regularization ranges (explain
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
22 to do lr first, then add regularization)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
23
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
24 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
25 __docformat__ = 'restructedtext en'
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
26
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
27 import ift6266
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
28 from scripts import setup_batches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
29 import pdb
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
30 import numpy
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
31
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
32 import theano
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
33 import theano.tensor as T
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
34 import time
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
35 import theano.tensor.nnet
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
36 import pylearn
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
37 import theano,pylearn.version
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
38 from pylearn.io import filetensor as ft
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
39
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
40 data_path = '/data/lisa/data/nist/by_class/'
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
41
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
42 class MLP(object):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
43 """Multi-Layer Perceptron Class
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
44
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
45 A multilayer perceptron is a feedforward artificial neural network model
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
46 that has one layer or more of hidden units and nonlinear activations.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
47 Intermidiate layers usually have as activation function thanh or the
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
48 sigmoid function while the top layer is a softamx layer.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
49 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
50
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
51
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
52
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
53 def __init__(self, input, n_in, n_hidden, n_out,learning_rate):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
54 """Initialize the parameters for the multilayer perceptron
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
55
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
56 :param input: symbolic variable that describes the input of the
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
57 architecture (one minibatch)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
58
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
59 :param n_in: number of input units, the dimension of the space in
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
60 which the datapoints lie
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
61
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
62 :param n_hidden: number of hidden units
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
63
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
64 :param n_out: number of output units, the dimension of the space in
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
65 which the labels lie
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
66
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
67 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
68
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
69 # initialize the parameters theta = (W1,b1,W2,b2) ; note that this
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
70 # example contains only one hidden layer, but one can have as many
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
71 # layers as he/she wishes, making the network deeper. The only
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
72 # problem making the network deep this way is during learning,
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
73 # backpropagation being unable to move the network from the starting
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
74 # point towards; this is where pre-training helps, giving a good
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
75 # starting point for backpropagation, but more about this in the
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
76 # other tutorials
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
77
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
78 # `W1` is initialized with `W1_values` which is uniformely sampled
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
79 # from -6./sqrt(n_in+n_hidden) and 6./sqrt(n_in+n_hidden)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
80 # the output of uniform if converted using asarray to dtype
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
81 # theano.config.floatX so that the code is runable on GPU
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
82 W1_values = numpy.asarray( numpy.random.uniform( \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
83 low = -numpy.sqrt(6./(n_in+n_hidden)), \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
84 high = numpy.sqrt(6./(n_in+n_hidden)), \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
85 size = (n_in, n_hidden)), dtype = theano.config.floatX)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
86 # `W2` is initialized with `W2_values` which is uniformely sampled
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
87 # from -6./sqrt(n_hidden+n_out) and 6./sqrt(n_hidden+n_out)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
88 # the output of uniform if converted using asarray to dtype
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
89 # theano.config.floatX so that the code is runable on GPU
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
90 W2_values = numpy.asarray( numpy.random.uniform(
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
91 low = -numpy.sqrt(6./(n_hidden+n_out)), \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
92 high= numpy.sqrt(6./(n_hidden+n_out)),\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
93 size= (n_hidden, n_out)), dtype = theano.config.floatX)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
94
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
95 self.W1 = theano.shared( value = W1_values )
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
96 self.b1 = theano.shared( value = numpy.zeros((n_hidden,),
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
97 dtype= theano.config.floatX))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
98 self.W2 = theano.shared( value = W2_values )
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
99 self.b2 = theano.shared( value = numpy.zeros((n_out,),
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
100 dtype= theano.config.floatX))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
101
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
102 #include the learning rate in the classifer so
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
103 #we can modify it on the fly when we want
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
104 lr_value=learning_rate
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
105 self.lr=theano.shared(value=lr_value)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
106 # symbolic expression computing the values of the hidden layer
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
107 self.hidden = T.tanh(T.dot(input, self.W1)+ self.b1)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
108
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
109
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
110
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
111 # symbolic expression computing the values of the top layer
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
112 self.p_y_given_x= T.nnet.softmax(T.dot(self.hidden, self.W2)+self.b2)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
113
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
114 # compute prediction as class whose probability is maximal in
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
115 # symbolic form
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
116 self.y_pred = T.argmax( self.p_y_given_x, axis =1)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
117 self.y_pred_num = T.argmax( self.p_y_given_x[0:9], axis =1)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
118
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
119
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
120
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
121
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
122 # L1 norm ; one regularization option is to enforce L1 norm to
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
123 # be small
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
124 self.L1 = abs(self.W1).sum() + abs(self.W2).sum()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
125
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
126 # square of L2 norm ; one regularization option is to enforce
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
127 # square of L2 norm to be small
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
128 self.L2_sqr = (self.W1**2).sum() + (self.W2**2).sum()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
129
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
130
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
131
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
132 def negative_log_likelihood(self, y):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
133 """Return the mean of the negative log-likelihood of the prediction
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
134 of this model under a given target distribution.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
135
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
136 .. math::
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
137
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
138 \frac{1}{|\mathcal{D}|}\mathcal{L} (\theta=\{W,b\}, \mathcal{D}) =
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
139 \frac{1}{|\mathcal{D}|}\sum_{i=0}^{|\mathcal{D}|} \log(P(Y=y^{(i)}|x^{(i)}, W,b)) \\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
140 \ell (\theta=\{W,b\}, \mathcal{D})
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
141
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
142
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
143 :param y: corresponds to a vector that gives for each example the
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
144 :correct label
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
145 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
146 return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]),y])
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
147
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
148
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
149
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
150
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
151 def errors(self, y):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
152 """Return a float representing the number of errors in the minibatch
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
153 over the total number of examples of the minibatch
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
154 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
155
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
156 # check if y has same dimension of y_pred
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
157 if y.ndim != self.y_pred.ndim:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
158 raise TypeError('y should have the same shape as self.y_pred',
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
159 ('y', target.type, 'y_pred', self.y_pred.type))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
160 # check if y is of the correct datatype
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
161 if y.dtype.startswith('int'):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
162 # the T.neq operator returns a vector of 0s and 1s, where 1
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
163 # represents a mistake in prediction
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
164 return T.mean(T.neq(self.y_pred, y))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
165 else:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
166 raise NotImplementedError()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
167
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
168
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
169 def mlp_full_nist( verbose = False,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
170 adaptive_lr = 1,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
171 train_data = 'all/all_train_data.ft',\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
172 train_labels = 'all/all_train_labels.ft',\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
173 test_data = 'all/all_test_data.ft',\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
174 test_labels = 'all/all_test_labels.ft',\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
175 learning_rate=0.5,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
176 L1_reg = 0.00,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
177 L2_reg = 0.0001,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
178 nb_max_exemples=1000000,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
179 batch_size=20,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
180 nb_hidden = 500,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
181 nb_targets = 62,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
182 tau=1e6,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
183 main_class="d",\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
184 start_ratio=1,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
185 end_ratio=1):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
186
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
187
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
188 configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr]
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
189
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
190 #save initial learning rate if classical adaptive lr is used
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
191 initial_lr=learning_rate
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
192
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
193 total_validation_error_list = []
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
194 total_train_error_list = []
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
195 learning_rate_list=[]
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
196 best_training_error=float('inf');
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
197
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
198 # set up batches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
199 batches = setup_batches.Batches()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
200 batches.set_batches(main_class, start_ratio,end_ratio,batch_size,verbose)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
201
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
202 train_batches = batches.get_train_batches()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
203 test_batches = batches.get_test_batches()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
204 validation_batches = batches.get_validation_batches()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
205
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
206 ishape = (32,32) # this is the size of NIST images
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
207
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
208 # allocate symbolic variables for the data
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
209 x = T.fmatrix() # the data is presented as rasterized images
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
210 y = T.lvector() # the labels are presented as 1D vector of
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
211 # [long int] labels
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
212
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
213 if verbose==True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
214 print 'finished parsing the data'
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
215 # construct the logistic regression class
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
216 classifier = MLP( input=x.reshape((batch_size,32*32)),\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
217 n_in=32*32,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
218 n_hidden=nb_hidden,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
219 n_out=nb_targets,
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
220 learning_rate=learning_rate)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
221
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
222
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
223
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
224
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
225 # the cost we minimize during training is the negative log likelihood of
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
226 # the model plus the regularization terms (L1 and L2); cost is expressed
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
227 # here symbolically
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
228 cost = classifier.negative_log_likelihood(y) \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
229 + L1_reg * classifier.L1 \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
230 + L2_reg * classifier.L2_sqr
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
231
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
232 # compiling a theano function that computes the mistakes that are made by
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
233 # the model on a minibatch
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
234 test_model = theano.function([x,y], classifier.errors(y))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
235
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
236 # compute the gradient of cost with respect to theta = (W1, b1, W2, b2)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
237 g_W1 = T.grad(cost, classifier.W1)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
238 g_b1 = T.grad(cost, classifier.b1)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
239 g_W2 = T.grad(cost, classifier.W2)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
240 g_b2 = T.grad(cost, classifier.b2)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
241
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
242 # specify how to update the parameters of the model as a dictionary
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
243 updates = \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
244 { classifier.W1: classifier.W1 - classifier.lr*g_W1 \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
245 , classifier.b1: classifier.b1 - classifier.lr*g_b1 \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
246 , classifier.W2: classifier.W2 - classifier.lr*g_W2 \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
247 , classifier.b2: classifier.b2 - classifier.lr*g_b2 }
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
248
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
249 # compiling a theano function `train_model` that returns the cost, but in
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
250 # the same time updates the parameter of the model based on the rules
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
251 # defined in `updates`
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
252 train_model = theano.function([x, y], cost, updates = updates )
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
253 n_minibatches = len(train_batches)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
254
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
255
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
256
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
257
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
258
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
259
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
260 #conditions for stopping the adaptation:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
261 #1) we have reached nb_max_exemples (this is rounded up to be a multiple of the train size)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
262 #2) validation error is going up twice in a row(probable overfitting)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
263
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
264 # This means we no longer stop on slow convergence as low learning rates stopped
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
265 # too fast.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
266
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
267 # no longer relevant
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
268 patience =nb_max_exemples/batch_size
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
269 patience_increase = 2 # wait this much longer when a new best is
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
270 # found
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
271 improvement_threshold = 0.995 # a relative improvement of this much is
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
272 # considered significant
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
273 validation_frequency = n_minibatches/4
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
274
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
275
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
276
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
277
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
278 best_params = None
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
279 best_validation_loss = float('inf')
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
280 best_iter = 0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
281 test_score = 0.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
282 start_time = time.clock()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
283 n_iter = nb_max_exemples/batch_size # nb of max times we are allowed to run through all exemples
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
284 n_iter = n_iter/n_minibatches + 1 #round up
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
285 n_iter=max(1,n_iter) # run at least once on short debug call
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
286 time_n=0 #in unit of exemples
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
287
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
288
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
289
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
290 if verbose == True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
291 print 'looping at most %d times through the data set' %n_iter
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
292 for iter in xrange(n_iter* n_minibatches):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
293
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
294 # get epoch and minibatch index
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
295 epoch = iter / n_minibatches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
296 minibatch_index = iter % n_minibatches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
297
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
298
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
299 if adaptive_lr==2:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
300 classifier.lr.value = tau*initial_lr/(tau+time_n)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
301
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
302 # get the minibatches corresponding to `iter` modulo
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
303 # `len(train_batches)`
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
304 x,y = train_batches[ minibatch_index ]
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
305 # convert to float
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
306 x_float = x/255.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
307 cost_ij = train_model(x_float,y)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
308
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
309 if (iter+1) % validation_frequency == 0:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
310 # compute zero-one loss on validation set
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
311
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
312 this_validation_loss = 0.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
313 for x,y in validation_batches:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
314 # sum up the errors for each minibatch
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
315 x_float = x/255.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
316 this_validation_loss += test_model(x_float,y)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
317 # get the average by dividing with the number of minibatches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
318 this_validation_loss /= len(validation_batches)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
319 #save the validation loss
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
320 total_validation_error_list.append(this_validation_loss)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
321
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
322 #get the training error rate
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
323 this_train_loss=0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
324 for x,y in train_batches:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
325 # sum up the errors for each minibatch
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
326 x_float = x/255.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
327 this_train_loss += test_model(x_float,y)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
328 # get the average by dividing with the number of minibatches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
329 this_train_loss /= len(train_batches)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
330 #save the validation loss
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
331 total_train_error_list.append(this_train_loss)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
332 if(this_train_loss<best_training_error):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
333 best_training_error=this_train_loss
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
334
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
335 if verbose == True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
336 print('epoch %i, minibatch %i/%i, validation error %f, training error %f %%' % \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
337 (epoch, minibatch_index+1, n_minibatches, \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
338 this_validation_loss*100.,this_train_loss*100))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
339 print 'learning rate = %f' %classifier.lr.value
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
340 print 'time = %i' %time_n
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
341
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
342
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
343 #save the learning rate
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
344 learning_rate_list.append(classifier.lr.value)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
345
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
346
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
347 # if we got the best validation score until now
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
348 if this_validation_loss < best_validation_loss:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
349 # save best validation score and iteration number
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
350 best_validation_loss = this_validation_loss
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
351 best_iter = iter
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
352 # reset patience if we are going down again
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
353 # so we continue exploring
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
354 patience=nb_max_exemples/batch_size
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
355 # test it on the test set
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
356 test_score = 0.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
357 for x,y in test_batches:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
358 x_float=x/255.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
359 test_score += test_model(x_float,y)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
360 test_score /= len(test_batches)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
361 if verbose == True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
362 print((' epoch %i, minibatch %i/%i, test error of best '
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
363 'model %f %%') %
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
364 (epoch, minibatch_index+1, n_minibatches,
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
365 test_score*100.))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
366
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
367 # if the validation error is going up, we are overfitting (or oscillating)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
368 # stop converging but run at least to next validation
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
369 # to check overfitting or ocsillation
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
370 # the saved weights of the model will be a bit off in that case
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
371 elif this_validation_loss >= best_validation_loss:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
372 #calculate the test error at this point and exit
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
373 # test it on the test set
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
374 # however, if adaptive_lr is true, try reducing the lr to
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
375 # get us out of an oscilliation
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
376 if adaptive_lr==1:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
377 classifier.lr.value=classifier.lr.value/2.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
378
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
379 test_score = 0.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
380 #cap the patience so we are allowed one more validation error
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
381 #calculation before aborting
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
382 patience = iter+validation_frequency+1
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
383 for x,y in test_batches:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
384 x_float=x/255.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
385 test_score += test_model(x_float,y)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
386 test_score /= len(test_batches)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
387 if verbose == True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
388 print ' validation error is going up, possibly stopping soon'
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
389 print((' epoch %i, minibatch %i/%i, test error of best '
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
390 'model %f %%') %
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
391 (epoch, minibatch_index+1, n_minibatches,
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
392 test_score*100.))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
393
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
394
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
395
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
396
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
397 if iter>patience:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
398 print 'we have diverged'
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
399 break
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
400
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
401
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
402 time_n= time_n + batch_size
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
403 end_time = time.clock()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
404 if verbose == True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
405 print(('Optimization complete. Best validation score of %f %% '
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
406 'obtained at iteration %i, with test performance %f %%') %
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
407 (best_validation_loss * 100., best_iter, test_score*100.))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
408 print ('The code ran for %f minutes' % ((end_time-start_time)/60.))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
409 print iter
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
410
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
411 #save the model and the weights
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
412 numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
413 numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
414 learning_rate_list=learning_rate_list)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
415
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
416 return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
417
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
418
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
419 if __name__ == '__main__':
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
420 mlp_full_nist(True)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
421
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
422 def jobman_mlp_full_nist(state,channel):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
423 (train_error,validation_error,test_error,nb_exemples,time)=mlp_full_nist(learning_rate=state.learning_rate,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
424 nb_max_exemples=state.nb_max_exemples,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
425 nb_hidden=state.nb_hidden,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
426 adaptive_lr=state.adaptive_lr,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
427 tau=state.tau,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
428 main_class=state.main_class,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
429 start_ratio=state.start_ratio,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
430 end_ratio=state.end_ratio)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
431 state.train_error=train_error
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
432 state.validation_error=validation_error
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
433 state.test_error=test_error
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
434 state.nb_exemples=nb_exemples
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
435 state.time=time
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
436 return channel.COMPLETE
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
437
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
438