comparison mlp.py @ 129:4c2280edcaf5

Fixed typos in learner.py
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Wed, 07 May 2008 21:22:56 -0400
parents 4efe6d36c061
children f6505ec32dc3
comparison
equal deleted inserted replaced
128:ee5507af2c60 129:4c2280edcaf5
5 5
6 # this is one of the simplest examples of a learner, and illustrates 6 # this is one of the simplest examples of a learner, and illustrates
7 # the use of theano 7 # the use of theano
8 8
9 9
10 class OneHiddenLayerNNetClassifier(MinibatchUpdatesTLearner): 10 class OneHiddenLayerNNetClassifier(OnlineGradientTLearner):
11 """ 11 """
12 Implement a straightforward classical feedforward 12 Implement a straightforward classical feedforward
13 one-hidden-layer neural net, with L2 regularization. 13 one-hidden-layer neural net, with L2 regularization.
14 14
15 The predictor parameters are obtained by minibatch/online gradient descent. 15 The predictor parameters are obtained by minibatch/online gradient descent.
81 self._output_activations =self._b2+t.dot(t.tanh(self._b1+t.dot(self._input,self._W1.T)),self._W2.T) 81 self._output_activations =self._b2+t.dot(t.tanh(self._b1+t.dot(self._input,self._W1.T)),self._W2.T)
82 self._nll,self._output = crossentropy_softmax_1hot(self._output_activations,self._target) 82 self._nll,self._output = crossentropy_softmax_1hot(self._output_activations,self._target)
83 self._output_class = t.argmax(self._output,1) 83 self._output_class = t.argmax(self._output,1)
84 self._class_error = self._output_class != self._target 84 self._class_error = self._output_class != self._target
85 self._minibatch_criterion = self._nll + self._regularization_term / t.shape(self._input)[0] 85 self._minibatch_criterion = self._nll + self._regularization_term / t.shape(self._input)[0]
86 MinibatchUpdatesTLearner.__init__(self) 86 OnlineGradientTLearner.__init__(self)
87 87
88 def attributeNames(self): 88 def attributeNames(self):
89 return ["parameters","b1","W2","b2","W2", "L2_regularizer","regularization_term"] 89 return ["parameters","b1","W2","b2","W2", "L2_regularizer","regularization_term"]
90 90
91 def parameterAttributes(self): 91 def parameterAttributes(self):