changeset 1358:8cc66dac6430

merge
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 11 Nov 2010 17:36:39 -0500
parents ffa2932a8cba
children 5db730bb0e8e
files pylearn/preprocessing/pca.py
diffstat 1 files changed, 13 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/preprocessing/pca.py	Thu Nov 11 16:34:38 2010 -0500
+++ b/pylearn/preprocessing/pca.py	Thu Nov 11 17:36:39 2010 -0500
@@ -15,12 +15,15 @@
 import numpy
 import scipy.linalg
 
-def diag_as_vector(x):
-    if x.ndim != 2:
-        raise TypeError('this diagonal is implemented only for matrices', x)
-    rval = x[0,:min(*x.shape)]
-    rval.strides = (rval.strides[0] + x.strides[0],)
-    return rval
+if 0:
+    #TODO : put this trick into Theano as an Op
+    #       inplace implementation of diag() Op.
+    def diag_as_vector(x):
+        if x.ndim != 2:
+            raise TypeError('this diagonal is implemented only for matrices', x)
+        rval = x[0,:min(*x.shape)]
+        rval.strides = (rval.strides[0] + x.strides[0],)
+        return rval
 
 
 def pca_from_cov(cov, lower=0, max_components=None, max_energy_fraction=None):
@@ -40,23 +43,12 @@
     #  a * v[:,i] = w[i] * vr[:,i]
     #  v.H * v = identity
 
-    assert w.min() >= -1e-12 # assert w is all pretty much positive
-    if w.min() < 0:
-        for i,wi in enumerate(w):
-            if wi < 0:
-                w[i]=0
-
 
-    # total variance can be computed at this point:
-    # note that vartot == w.sum()
+    # total variance (vartot) can be computed at this point:
     vartot = w.sum()
-    if 0: 
-        # you can do this if you want, but it just slows things down
-        vartot_cov = diag_as_vector(cov).sum()
-        assert numpy.allclose(vartot_cov, vartot)
 
+    # sort the eigenvals and vecs by decreasing magnitude
     a = numpy.argsort(w)[::-1]
-
     w = w[a]
     v = v[:,a]
 
@@ -73,9 +65,8 @@
             while (energy < max_energy_fraction * vartot) and (i < len(w)):
                 energy += w[i]
                 i += 1
-            if i < len(w):
-                w = w[:i]
-                v = v[:,:i]
+            w = w[:(i-1)]
+            v = v[:,:(i-1)]
     return w,v