comparison deep/stacked_dae/v_sylvain/stacked_dae.py @ 363:14b28e43ce4e

Correction d'un bug dans le pre-train du SDA causé par tanh
author SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca>
date Thu, 22 Apr 2010 13:43:04 -0400
parents bc4464c0894c
children d391ad815d89
comparison
equal deleted inserted replaced
362:793e89fcdab7 363:14b28e43ce4e
86 self.W = theano.shared(value = W_values) 86 self.W = theano.shared(value = W_values)
87 87
88 b_values = numpy.zeros((n_out,), dtype= theano.config.floatX) 88 b_values = numpy.zeros((n_out,), dtype= theano.config.floatX)
89 self.b = theano.shared(value= b_values) 89 self.b = theano.shared(value= b_values)
90 90
91 self.output = (T.tanh(T.dot(input, self.W) + self.b) + 1) /2 91 self.output = (T.tanh(T.dot(input, self.W) + self.b) + 1.0)/2.0
92 # ( *+ 1) /2 is because tanh goes from -1 to 1 and sigmoid goes from 0 to 1 92 # ( *+ 1) /2 is because tanh goes from -1 to 1 and sigmoid goes from 0 to 1
93 # I want to use tanh, but the image has to stay the same. The correction is necessary. 93 # I want to use tanh, but the image has to stay the same. The correction is necessary.
94 self.params = [self.W, self.b] 94 self.params = [self.W, self.b]
95 95
96 96
183 # minibatch. We need to compute the average of all these to get 183 # minibatch. We need to compute the average of all these to get
184 # the cost of the minibatch 184 # the cost of the minibatch
185 185
186 #Or use a Tanh everything is always between 0 and 1, the range is 186 #Or use a Tanh everything is always between 0 and 1, the range is
187 #changed so it remain the same as when sigmoid is used 187 #changed so it remain the same as when sigmoid is used
188 self.y = (T.tanh(T.dot(self.tilde_x, self.W ) + self.b)+1.0)/2.0 188 self.y = (T.tanh(T.dot(self.tilde_x, self.W ) + self.b)+1.0)/2.0
189 189
190 z_a = T.dot(self.y, self.W_prime) + self.b_prime 190 z_a = T.dot(self.y, self.W_prime) + self.b_prime
191 self.z = (T.tanh(z_a + self.b_prime)+1.0) / 2.0 191 self.z = (T.tanh(z_a )+1.0) / 2.0
192 #To ensure to do not have a log(0) operation 192 #To ensure to do not have a log(0) operation
193 if self.z <= 0: 193 if self.z <= 0:
194 self.z = 0.000001 194 self.z = 0.000001
195 if self.z >= 1: 195 if self.z >= 1:
196 self.z = 0.999999 196 self.z = 0.999999