comparison mlp.py @ 180:2698c0feeb54

mlp seems to work!
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Tue, 13 May 2008 15:35:43 -0400
parents 9911d2cc3c01
children 4afb41e61fcf
comparison
equal deleted inserted replaced
179:9911d2cc3c01 180:2698c0feeb54
85 self._W2 = t.matrix('W2') 85 self._W2 = t.matrix('W2')
86 self._b1 = t.row('b1') 86 self._b1 = t.row('b1')
87 self._b2 = t.row('b2') 87 self._b2 = t.row('b2')
88 self._regularization_term = self._L2_regularizer * (t.sum(self._W1*self._W1) + t.sum(self._W2*self._W2)) 88 self._regularization_term = self._L2_regularizer * (t.sum(self._W1*self._W1) + t.sum(self._W2*self._W2))
89 self._output_activations =self._b2+t.dot(t.tanh(self._b1+t.dot(self._input,self._W1.T)),self._W2.T) 89 self._output_activations =self._b2+t.dot(t.tanh(self._b1+t.dot(self._input,self._W1.T)),self._W2.T)
90 self._nll,self._output = crossentropy_softmax_1hot(Print("output_activations")(self._output_activations),self._target_vector) 90 self._nll,self._output = crossentropy_softmax_1hot(self._output_activations,self._target_vector)
91 self._output_class = t.argmax(self._output,1) 91 self._output_class = t.argmax(self._output,1)
92 self._class_error = t.neq(self._output_class,self._target_vector) 92 self._class_error = t.neq(self._output_class,self._target_vector)
93 self._minibatch_criterion = self._nll + self._regularization_term / t.shape(self._input)[0] 93 self._minibatch_criterion = self._nll + self._regularization_term / t.shape(self._input)[0]
94 OnlineGradientTLearner.__init__(self) 94 OnlineGradientTLearner.__init__(self)
95 95
99 def parameterAttributes(self): 99 def parameterAttributes(self):
100 return ["b1","W1", "b2", "W2"] 100 return ["b1","W1", "b2", "W2"]
101 101
102 def updateMinibatchInputFields(self): 102 def updateMinibatchInputFields(self):
103 return ["input","target"] 103 return ["input","target"]
104
105 def updateMinibatchInputAttributes(self):
106 return OnlineGradientTLearner.updateMinibatchInputAttributes(self)+["L2_regularizer"]
104 107
105 def updateEndOutputAttributes(self): 108 def updateEndOutputAttributes(self):
106 return ["regularization_term"] 109 return ["regularization_term"]
107 110
108 def lossAttribute(self): 111 def lossAttribute(self):
139 142
140 def isLastEpoch(self): 143 def isLastEpoch(self):
141 self._n_epochs +=1 144 self._n_epochs +=1
142 return self._n_epochs>=self._max_n_epochs 145 return self._n_epochs>=self._max_n_epochs
143 146
144 def updateMinibatch(self,minibatch): 147 def debug_updateMinibatch(self,minibatch):
145 # make sure all required fields are allocated and initialized 148 # make sure all required fields are allocated and initialized
146 self.allocate(minibatch) 149 self.allocate(minibatch)
147 input_attributes = self.names2attributes(self.updateMinibatchInputAttributes()) 150 input_attributes = self.names2attributes(self.updateMinibatchInputAttributes())
148 input_fields = minibatch(*self.updateMinibatchInputFields()) 151 input_fields = minibatch(*self.updateMinibatchInputFields())
149 print 'input attributes', input_attributes 152 print 'input attributes', input_attributes