changeset 974:f2cdcc71ece1

mcRBM - added L1 penalties and normal sign convention to contrastive grad
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 23 Aug 2010 16:02:02 -0400
parents aa201f357d7b
children 38e66e0da66a
files pylearn/algorithms/mcRBM.py
diffstat 1 files changed, 18 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/algorithms/mcRBM.py	Mon Aug 23 15:59:40 2010 -0400
+++ b/pylearn/algorithms/mcRBM.py	Mon Aug 23 16:02:02 2010 -0400
@@ -161,7 +161,7 @@
         lr = [lr for p in params]
     except TypeError:
         pass
-    updates = [(p, p + plr * gp) for (plr, p, gp) in zip(lr, params, grads)]
+    updates = [(p, p - plr * gp) for (plr, p, gp) in zip(lr, params, grads)]
     return updates
 
 def as_shared(x, name=None, dtype=floatX):
@@ -295,7 +295,7 @@
     def free_energy_given_v(self, v):
         return free_energy_given_v(self.params, v)
 
-    def contrastive_gradient(self, pos_v, neg_v):
+    def contrastive_gradient(self, pos_v, neg_v, U_l1_penalty=0, W_l1_penalty=0):
         """Return a list of gradient expressions for self.params
 
         :param pos_v: positive-phase sample of visible units
@@ -306,7 +306,22 @@
 
         gpos_FE = theano.tensor.grad(pos_FE.sum(), self.params)
         gneg_FE = theano.tensor.grad(neg_FE.sum(), self.params)
-        return [ gn - gp for (gp,gn) in zip(gpos_FE, gneg_FE)]
+        rval = [ gp - gn for (gp,gn) in zip(gpos_FE, gneg_FE)]
+        rval[0] = rval[0] - TT.sign(self.U)*U_l1_penalty
+        rval[1] = rval[1] - TT.sign(self.W)*W_l1_penalty
+        return rval
+
+from pylearn.dataset_ops.protocol import TensorFnDataset
+from pylearn.dataset_ops.memo import memo
+import scipy.io
+@memo
+def load_mcRBM_demo_patches():
+    d = scipy.io.loadmat('/u/bergstrj/cvs/articles/2010/spike_slab_RBM/src/marcaurelio/training_colorpatches_16x16_demo.mat')
+    totnumcases = d["whitendata"].shape[0]
+    #d = d["whitendata"][0:np.floor(totnumcases/batch_size)*batch_size,:].copy() 
+    d = d["whitendata"].copy()
+    return d
+
 
 if __name__ == '__main__':