annotate baseline/mlp/ratio_classes/mlp_nist_ratio.py @ 448:b2a7d93caa0f

Correction d'un petit bug d'indice. Le script est maintenant plus juste
author SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca>
date Fri, 07 May 2010 17:24:21 -0400
parents d8129a09ffb1
children
rev   line source
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
1 # -*- coding: utf-8 -*-
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
2 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
3 This tutorial introduces the multilayer perceptron using Theano.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
4
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
5 A multilayer perceptron is a logistic regressor where
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
6 instead of feeding the input to the logistic regression you insert a
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
7 intermidiate layer, called the hidden layer, that has a nonlinear
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
8 activation function (usually tanh or sigmoid) . One can use many such
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
9 hidden layers making the architecture deep. The tutorial will also tackle
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
10 the problem of MNIST digit classification.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
11
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
12 .. math::
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
13
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
14 f(x) = G( b^{(2)} + W^{(2)}( s( b^{(1)} + W^{(1)} x))),
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
15
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
16 References:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
17
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
18 - textbooks: "Pattern Recognition and Machine Learning" -
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
19 Christopher M. Bishop, section 5
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
20
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
21 TODO: recommended preprocessing, lr ranges, regularization ranges (explain
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
22 to do lr first, then add regularization)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
23
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
24 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
25 __docformat__ = 'restructedtext en'
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
26
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
27 import setup_batches
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
28 import pdb
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
29 import numpy
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
30
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
31 import theano
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
32 import theano.tensor as T
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
33 import time
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
34 import theano.tensor.nnet
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
35 import pylearn
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
36 import theano,pylearn.version
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
37 from pylearn.io import filetensor as ft
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
38
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
39 data_path = '/data/lisa/data/nist/by_class/'
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
40
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
41 class MLP(object):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
42 """Multi-Layer Perceptron Class
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
43
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
44 A multilayer perceptron is a feedforward artificial neural network model
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
45 that has one layer or more of hidden units and nonlinear activations.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
46 Intermidiate layers usually have as activation function thanh or the
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
47 sigmoid function while the top layer is a softamx layer.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
48 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
49
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
50
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
51
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
52 def __init__(self, input, n_in, n_hidden, n_out,learning_rate, test_subclass):
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
53 """Initialize the parameters for the multilayer perceptron
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
54
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
55 :param input: symbolic variable that describes the input of the
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
56 architecture (one minibatch)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
57
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
58 :param n_in: number of input units, the dimension of the space in
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
59 which the datapoints lie
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
60
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
61 :param n_hidden: number of hidden units
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
62
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
63 :param n_out: number of output units, the dimension of the space in
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
64 which the labels lie
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
65
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
66 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
67
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
68 # initialize the parameters theta = (W1,b1,W2,b2) ; note that this
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
69 # example contains only one hidden layer, but one can have as many
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
70 # layers as he/she wishes, making the network deeper. The only
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
71 # problem making the network deep this way is during learning,
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
72 # backpropagation being unable to move the network from the starting
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
73 # point towards; this is where pre-training helps, giving a good
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
74 # starting point for backpropagation, but more about this in the
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
75 # other tutorials
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
76
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
77 # `W1` is initialized with `W1_values` which is uniformely sampled
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
78 # from -6./sqrt(n_in+n_hidden) and 6./sqrt(n_in+n_hidden)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
79 # the output of uniform if converted using asarray to dtype
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
80 # theano.config.floatX so that the code is runable on GPU
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
81 W1_values = numpy.asarray( numpy.random.uniform( \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
82 low = -numpy.sqrt(6./(n_in+n_hidden)), \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
83 high = numpy.sqrt(6./(n_in+n_hidden)), \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
84 size = (n_in, n_hidden)), dtype = theano.config.floatX)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
85 # `W2` is initialized with `W2_values` which is uniformely sampled
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
86 # from -6./sqrt(n_hidden+n_out) and 6./sqrt(n_hidden+n_out)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
87 # the output of uniform if converted using asarray to dtype
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
88 # theano.config.floatX so that the code is runable on GPU
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
89 W2_values = numpy.asarray( numpy.random.uniform(
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
90 low = -numpy.sqrt(6./(n_hidden+n_out)), \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
91 high= numpy.sqrt(6./(n_hidden+n_out)),\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
92 size= (n_hidden, n_out)), dtype = theano.config.floatX)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
93
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
94 self.W1 = theano.shared( value = W1_values )
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
95 self.b1 = theano.shared( value = numpy.zeros((n_hidden,),
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
96 dtype= theano.config.floatX))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
97 self.W2 = theano.shared( value = W2_values )
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
98 self.b2 = theano.shared( value = numpy.zeros((n_out,),
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
99 dtype= theano.config.floatX))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
100
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
101 #include the learning rate in the classifer so
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
102 #we can modify it on the fly when we want
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
103 lr_value=learning_rate
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
104 self.lr=theano.shared(value=lr_value)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
105 # symbolic expression computing the values of the hidden layer
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
106 self.hidden = T.tanh(T.dot(input, self.W1)+ self.b1)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
107
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
108
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
109
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
110 # symbolic expression computing the values of the top layer
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
111 self.p_y_given_x= T.nnet.softmax(T.dot(self.hidden, self.W2)+self.b2)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
112
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
113 # compute prediction as class whose probability is maximal in
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
114 # symbolic form
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
115 #self.y_pred = T.argmax( self.p_y_given_x, axis =1)
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
116 #self.y_pred_num = T.argmax( self.p_y_given_x[0:9], axis =1)
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
117
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
118 self.test_subclass = test_subclass
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
119
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
120 #if (self.test_subclass == "u"):
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
121 # self.y_pred = T.argmax( self.p_y_given_x[10:35], axis =1) + 10
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
122 #elif (self.test_subclass == "l"):
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
123 # self.y_pred = T.argmax( self.p_y_given_x[35:], axis =1) + 35
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
124 #elif (self.test_subclass == "d"):
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
125 # self.y_pred = T.argmax( self.p_y_given_x[0:9], axis =1)
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
126 #else:
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
127 self.y_pred = T.argmax( self.p_y_given_x, axis =1)
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
128
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
129 # L1 norm ; one regularization option is to enforce L1 norm to
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
130 # be small
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
131 self.L1 = abs(self.W1).sum() + abs(self.W2).sum()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
132
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
133 # square of L2 norm ; one regularization option is to enforce
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
134 # square of L2 norm to be small
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
135 self.L2_sqr = (self.W1**2).sum() + (self.W2**2).sum()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
136
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
137
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
138
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
139 def negative_log_likelihood(self, y):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
140 """Return the mean of the negative log-likelihood of the prediction
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
141 of this model under a given target distribution.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
142
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
143 .. math::
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
144
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
145 \frac{1}{|\mathcal{D}|}\mathcal{L} (\theta=\{W,b\}, \mathcal{D}) =
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
146 \frac{1}{|\mathcal{D}|}\sum_{i=0}^{|\mathcal{D}|} \log(P(Y=y^{(i)}|x^{(i)}, W,b)) \\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
147 \ell (\theta=\{W,b\}, \mathcal{D})
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
148
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
149
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
150 :param y: corresponds to a vector that gives for each example the
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
151 :correct label
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
152 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
153 return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]),y])
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
154
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
155
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
156
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
157
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
158 def errors(self, y):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
159 """Return a float representing the number of errors in the minibatch
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
160 over the total number of examples of the minibatch
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
161 """
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
162
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
163 # check if y has same dimension of y_pred
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
164 if y.ndim != self.y_pred.ndim:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
165 raise TypeError('y should have the same shape as self.y_pred',
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
166 ('y', target.type, 'y_pred', self.y_pred.type))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
167 # check if y is of the correct datatype
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
168 if y.dtype.startswith('int'):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
169 # the T.neq operator returns a vector of 0s and 1s, where 1
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
170 # represents a mistake in prediction
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
171 return T.mean(T.neq(self.y_pred, y))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
172 else:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
173 raise NotImplementedError()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
174
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
175
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
176 def mlp_full_nist( verbose = False,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
177 adaptive_lr = 1,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
178 train_data = 'all/all_train_data.ft',\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
179 train_labels = 'all/all_train_labels.ft',\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
180 test_data = 'all/all_test_data.ft',\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
181 test_labels = 'all/all_test_labels.ft',\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
182 learning_rate=0.5,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
183 L1_reg = 0.00,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
184 L2_reg = 0.0001,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
185 nb_max_exemples=1000000,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
186 batch_size=20,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
187 nb_hidden = 500,\
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
188 nb_targets = 26,\
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
189 tau=1e6,\
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
190 main_class="l",\
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
191 start_ratio=1,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
192 end_ratio=1):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
193
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
194
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
195 configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr]
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
196
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
197 #save initial learning rate if classical adaptive lr is used
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
198 initial_lr=learning_rate
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
199
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
200 total_validation_error_list = []
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
201 total_train_error_list = []
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
202 learning_rate_list=[]
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
203 best_training_error=float('inf');
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
204
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
205 # set up batches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
206 batches = setup_batches.Batches()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
207 batches.set_batches(main_class, start_ratio,end_ratio,batch_size,verbose)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
208
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
209 train_batches = batches.get_train_batches()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
210 test_batches = batches.get_test_batches()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
211 validation_batches = batches.get_validation_batches()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
212
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
213 ishape = (32,32) # this is the size of NIST images
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
214
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
215 # allocate symbolic variables for the data
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
216 x = T.fmatrix() # the data is presented as rasterized images
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
217 y = T.lvector() # the labels are presented as 1D vector of
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
218 # [long int] labels
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
219
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
220 if verbose==True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
221 print 'finished parsing the data'
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
222 # construct the logistic regression class
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
223 classifier = MLP( input=x.reshape((batch_size,32*32)),\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
224 n_in=32*32,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
225 n_hidden=nb_hidden,\
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
226 n_out=nb_targets,\
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
227 learning_rate=learning_rate,\
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
228 test_subclass=main_class)
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
229
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
230
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
231
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
232
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
233 # the cost we minimize during training is the negative log likelihood of
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
234 # the model plus the regularization terms (L1 and L2); cost is expressed
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
235 # here symbolically
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
236 cost = classifier.negative_log_likelihood(y) \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
237 + L1_reg * classifier.L1 \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
238 + L2_reg * classifier.L2_sqr
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
239
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
240 # compiling a theano function that computes the mistakes that are made by
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
241 # the model on a minibatch
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
242 test_model = theano.function([x,y], classifier.errors(y))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
243
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
244 # compute the gradient of cost with respect to theta = (W1, b1, W2, b2)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
245 g_W1 = T.grad(cost, classifier.W1)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
246 g_b1 = T.grad(cost, classifier.b1)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
247 g_W2 = T.grad(cost, classifier.W2)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
248 g_b2 = T.grad(cost, classifier.b2)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
249
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
250 # specify how to update the parameters of the model as a dictionary
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
251 updates = \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
252 { classifier.W1: classifier.W1 - classifier.lr*g_W1 \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
253 , classifier.b1: classifier.b1 - classifier.lr*g_b1 \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
254 , classifier.W2: classifier.W2 - classifier.lr*g_W2 \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
255 , classifier.b2: classifier.b2 - classifier.lr*g_b2 }
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
256
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
257 # compiling a theano function `train_model` that returns the cost, but in
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
258 # the same time updates the parameter of the model based on the rules
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
259 # defined in `updates`
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
260 train_model = theano.function([x, y], cost, updates = updates )
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
261 n_minibatches = len(train_batches)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
262
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
263
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
264
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
265
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
266
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
267
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
268 #conditions for stopping the adaptation:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
269 #1) we have reached nb_max_exemples (this is rounded up to be a multiple of the train size)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
270 #2) validation error is going up twice in a row(probable overfitting)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
271
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
272 # This means we no longer stop on slow convergence as low learning rates stopped
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
273 # too fast.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
274
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
275 # no longer relevant
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
276 patience =nb_max_exemples/batch_size
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
277 patience_increase = 2 # wait this much longer when a new best is
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
278 # found
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
279 improvement_threshold = 0.995 # a relative improvement of this much is
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
280 # considered significant
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
281 validation_frequency = n_minibatches/4
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
282
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
283
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
284
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
285
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
286 best_params = None
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
287 best_validation_loss = float('inf')
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
288 best_iter = 0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
289 test_score = 0.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
290 start_time = time.clock()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
291 n_iter = nb_max_exemples/batch_size # nb of max times we are allowed to run through all exemples
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
292 n_iter = n_iter/n_minibatches + 1 #round up
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
293 n_iter=max(1,n_iter) # run at least once on short debug call
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
294 time_n=0 #in unit of exemples
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
295
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
296 if (main_class == "u"):
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
297 class_offset = 10
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
298 elif (main_class == "l"):
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
299 class_offset = 36
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
300 else:
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
301 class_offset = 0
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
302
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
303
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
304 if verbose == True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
305 print 'looping at most %d times through the data set' %n_iter
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
306 for iter in xrange(n_iter* n_minibatches):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
307
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
308 # get epoch and minibatch index
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
309 epoch = iter / n_minibatches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
310 minibatch_index = iter % n_minibatches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
311
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
312
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
313 if adaptive_lr==2:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
314 classifier.lr.value = tau*initial_lr/(tau+time_n)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
315
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
316 # get the minibatches corresponding to `iter` modulo
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
317 # `len(train_batches)`
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
318 x,y = train_batches[ minibatch_index ]
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
319
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
320 y = y - class_offset
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
321
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
322 # convert to float
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
323 x_float = x/255.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
324 cost_ij = train_model(x_float,y)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
325
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
326 if (iter+1) % validation_frequency == 0:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
327 # compute zero-one loss on validation set
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
328
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
329 this_validation_loss = 0.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
330 for x,y in validation_batches:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
331 # sum up the errors for each minibatch
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
332 y = y - class_offset
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
333 x_float = x/255.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
334 this_validation_loss += test_model(x_float,y)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
335 # get the average by dividing with the number of minibatches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
336 this_validation_loss /= len(validation_batches)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
337 #save the validation loss
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
338 total_validation_error_list.append(this_validation_loss)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
339
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
340 #get the training error rate
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
341 this_train_loss=0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
342 for x,y in train_batches:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
343 # sum up the errors for each minibatch
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
344 y = y - class_offset
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
345 x_float = x/255.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
346 this_train_loss += test_model(x_float,y)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
347 # get the average by dividing with the number of minibatches
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
348 this_train_loss /= len(train_batches)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
349 #save the validation loss
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
350 total_train_error_list.append(this_train_loss)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
351 if(this_train_loss<best_training_error):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
352 best_training_error=this_train_loss
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
353
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
354 if verbose == True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
355 print('epoch %i, minibatch %i/%i, validation error %f, training error %f %%' % \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
356 (epoch, minibatch_index+1, n_minibatches, \
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
357 this_validation_loss*100.,this_train_loss*100))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
358 print 'learning rate = %f' %classifier.lr.value
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
359 print 'time = %i' %time_n
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
360
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
361
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
362 #save the learning rate
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
363 learning_rate_list.append(classifier.lr.value)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
364
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
365
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
366 # if we got the best validation score until now
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
367 if this_validation_loss < best_validation_loss:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
368 # save best validation score and iteration number
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
369 best_validation_loss = this_validation_loss
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
370 best_iter = iter
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
371 # reset patience if we are going down again
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
372 # so we continue exploring
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
373 patience=nb_max_exemples/batch_size
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
374 # test it on the test set
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
375 test_score = 0.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
376 for x,y in test_batches:
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
377 y = y - class_offset
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
378 x_float=x/255.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
379 test_score += test_model(x_float,y)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
380 test_score /= len(test_batches)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
381 if verbose == True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
382 print((' epoch %i, minibatch %i/%i, test error of best '
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
383 'model %f %%') %
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
384 (epoch, minibatch_index+1, n_minibatches,
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
385 test_score*100.))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
386
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
387 # if the validation error is going up, we are overfitting (or oscillating)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
388 # stop converging but run at least to next validation
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
389 # to check overfitting or ocsillation
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
390 # the saved weights of the model will be a bit off in that case
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
391 elif this_validation_loss >= best_validation_loss:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
392 #calculate the test error at this point and exit
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
393 # test it on the test set
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
394 # however, if adaptive_lr is true, try reducing the lr to
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
395 # get us out of an oscilliation
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
396 if adaptive_lr==1:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
397 classifier.lr.value=classifier.lr.value/2.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
398
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
399 test_score = 0.
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
400 #cap the patience so we are allowed one more validation error
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
401 #calculation before aborting
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
402 patience = iter+validation_frequency+1
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
403 for x,y in test_batches:
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
404 y = y - class_offset
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
405 x_float=x/255.0
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
406 test_score += test_model(x_float,y)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
407 test_score /= len(test_batches)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
408 if verbose == True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
409 print ' validation error is going up, possibly stopping soon'
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
410 print((' epoch %i, minibatch %i/%i, test error of best '
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
411 'model %f %%') %
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
412 (epoch, minibatch_index+1, n_minibatches,
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
413 test_score*100.))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
414
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
415
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
416
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
417
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
418 if iter>patience:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
419 print 'we have diverged'
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
420 break
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
421
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
422
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
423 time_n= time_n + batch_size
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
424 end_time = time.clock()
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
425 if verbose == True:
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
426 print(('Optimization complete. Best validation score of %f %% '
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
427 'obtained at iteration %i, with test performance %f %%') %
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
428 (best_validation_loss * 100., best_iter, test_score*100.))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
429 print ('The code ran for %f minutes' % ((end_time-start_time)/60.))
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
430 print iter
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
431
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
432 #save the model and the weights
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
433 numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
434 numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
435 learning_rate_list=learning_rate_list)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
436
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
437 return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
438
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
439
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
440 if __name__ == '__main__':
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
441 mlp_full_nist(True)
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
442
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
443 def jobman_mlp_full_nist(state,channel):
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
444 (train_error,validation_error,test_error,nb_exemples,time)=mlp_full_nist(learning_rate=state.learning_rate,\
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
445 nb_hidden=state.nb_hidden,\
435
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
446 main_class=state.main_class,\
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
447 start_ratio=state.ratio,\
d8129a09ffb1 bug fix in output
Guillaume Sicard <guitch21@gmail.com>
parents: 357
diff changeset
448 end_ratio=state.ratio)
357
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
449 state.train_error=train_error
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
450 state.validation_error=validation_error
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
451 state.test_error=test_error
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
452 state.nb_exemples=nb_exemples
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
453 state.time=time
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
454 return channel.COMPLETE
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
455
9a7b74927f7d version mlp modifiée pour la selection du ratio de la classe principale
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
456