changeset 651:d03b5d8e4bf6

revised classification and LogReg_New
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 04 Feb 2009 20:02:05 -0500
parents 83e8fe9b1c82
children 2704c8688ced 40cae12a9bb8
files pylearn/algorithms/logistic_regression.py
diffstat 1 files changed, 36 insertions(+), 13 deletions(-)
--- a/pylearn/algorithms/logistic_regression.py	Wed Feb 04 18:04:05 2009 -0500
+++ b/pylearn/algorithms/logistic_regression.py	Wed Feb 04 20:02:05 2009 -0500
@@ -196,27 +196,49 @@
 
     @staticmethod
     def xent(p, q):
-        """The cross-entropy between the prediction from `input`, and the true `target`.
+        """cross-entropy (row-wise)
+
+        :type p: M x N symbolic matrix (sparse or dense)
+
+        :param p: each row is a true distribution over N things
+
+        :type q: M x N symbolic matrix (sparse or dense)
 
-        This function returns a symbolic vector, with the cross-entropy for each row in
-        `input`.  
+        :param q: each row is an approximating distribution over N things
+
+        :rtype: symbolic vector of length M
+
+        :returns: the cross-entropy between each row of p and the corresponding row of q.
         
-        Hint: To sum these costs into a scalar value, use "xent(input, target).sum()"
+
+        Hint: To sum row-wise costs into a scalar value, use "xent(p, q).sum()"
         """
-        return p * tensor.log(q)
+        return -(p * tensor.log(q)).sum(axis=1)
 
     @staticmethod
-    def errors(prediction, target):
-        """The zero-one error of the prediction from `input`, with respect to the true `target`.
+    def errors(target, prediction):
+        """classification error (row-wise)
+
+        :type target: M x N symbolic matrix (sparse or dense)
+
+        :param target: each row is the true distribution (e.g. a one-hot vector) over N classes
+
+        :type prediction: M x N symbolic matrix (sparse or dense)
 
-        This function returns a symbolic vector, with the incorrectness of each prediction
-        (made row-wise from `input`).
+        :param prediction: each row is a predicted distribution over the same N classes
+
+        :rtype: symbolic vector of length M
+
+        :returns: a vector with 0 for every row pair whose maxima occur in the same
+            position, and 1 for every other row pair.
         
+
-        Hint: Count errors with "errors(prediction, target).sum()", and get the error-rate with
-        "errors(prediction, target).mean()"
+        Hint: Count errors with "errors(target, prediction).sum()", and get the error-rate with
+        "errors(target, prediction).mean()"
-
         """
-        return tensor.neq(tensor.argmax(prediction), target)
+        return tensor.neq(
+                tensor.argmax(prediction, axis=1),
+                tensor.argmax(target, axis=1))
 
 class LogReg_New(module.FancyModule):
     """A symbolic module for performing multi-class logistic regression."""
@@ -234,6 +256,7 @@
 
         self.w = w if w is not None else module.Member(T.dmatrix())
         self.b = b if b is not None else module.Member(T.dvector())
+
     def _instance_initialize(self, obj):
         obj.w = N.zeros((self.n_in, self.n_out))
         obj.b = N.zeros(self.n_out)
@@ -256,8 +279,8 @@
         return tensor.argmax(self.activation(input))
 
     def xent(self, input, target):
-        return classification.xent(self.softmax(input), target)
+        return classification.xent(target, self.softmax(input))
 
     def errors(self, input, target):
-        return classification.errors(self.softmax(input), target)
+        return classification.errors(target, self.softmax(input))
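
The last hunk flips LogReg_New's calls so the target comes first, matching the new
classification signatures. A hedged sketch of that calling convention, with a bare Theano
softmax standing in for the module's activation (names here are illustrative, not pylearn API):

    # Sketch only: mirrors the (target, prediction) argument order now used by
    # LogReg_New.xent / LogReg_New.errors.
    import theano
    import theano.tensor as tensor

    x = tensor.dmatrix('x')    # input features, one row per example
    w = tensor.dmatrix('w')    # weight matrix, n_in x n_out
    b = tensor.dvector('b')    # bias vector, length n_out
    y = tensor.dmatrix('y')    # one-hot targets, one row per example

    prediction = tensor.nnet.softmax(tensor.dot(x, w) + b)

    # As the docstring hints suggest: total cost via .sum(), error rate via .mean().
    total_xent = (-(y * tensor.log(prediction)).sum(axis=1)).sum()
    error_rate = tensor.neq(tensor.argmax(prediction, axis=1),
                            tensor.argmax(y, axis=1)).mean()

Minimizing total_xent with respect to w and b (for example via tensor.grad) gives the usual
multi-class logistic regression training criterion.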