changeset 541:5b4ccbf022c8

Working RBM implementation
author desjagui@atchoum.iro.umontreal.ca
date Thu, 13 Nov 2008 18:40:45 -0500
parents 85d3300c9a9c
children de6de7c2c54b 24dfe316e79a
files pylearn/algorithms/rbm.py
diffstat 1 files changed, 47 insertions(+), 22 deletions(-)
--- a/pylearn/algorithms/rbm.py	Thu Nov 13 17:54:56 2008 -0500
+++ b/pylearn/algorithms/rbm.py	Thu Nov 13 18:40:45 2008 -0500
@@ -12,14 +12,18 @@
 from .minimizer import make_minimizer
 from .stopper import make_stopper
 
-class RBM(module.FancyModule):
+class RBM(T.RModule):
 
     # is it really necessary to pass ALL of these ? - GD
     def __init__(self,
             nvis=None, nhid=None,
             input=None,
-            w=None, hidb=None, visb=None):
+            w=None, hidb=None, visb=None,
+            seed=0, lr=0.1):
+      
         super(RBM, self).__init__()
+        self.nhid, self.nvis = nhid, nvis
+        self.lr = lr
        
         # symbolic theano stuff
         # what about multidimensional inputs/outputs ? do they have to be 
@@ -27,28 +31,48 @@
         self.w = w if w is not None else module.Member(T.dmatrix())
         self.visb = visb if visb is not None else module.Member(T.dvector())
         self.hidb = hidb if hidb is not None else module.Member(T.dvector())
-
+        self.seed = seed
+       
         # 1-step Markov chain
-        self.hid = sigmoid(T.dot(self.w,self.input) + self.hidb)
-        self.hid_sample = self.hid #TODO: sample!
-        self.vis = sigmoid(T.dot(self.w.T, self.hid) + self.visb)
-        self.vis_sample = self.vis #TODO: sample!
-        self.neg_hid = sigmoid(T.dot(self.w, self.vis) + self.hidb)
+        vis = T.dmatrix()
+        hid = sigmoid(T.dot(vis, self.w) + self.hidb)
+        hid_sample = self.random.binomial(T.shape(hid), 1, hid)
+        neg_vis = sigmoid(T.dot(hid_sample, self.w.T) + self.visb)
+        neg_vis_sample = self.random.binomial(T.shape(neg_vis), 1, neg_vis)
+        neg_hid = sigmoid(T.dot(neg_vis_sample, self.w) + self.hidb)
+
+        # function which executes one step of the Markov chain (cd updates are applied separately by cd_update)
+        self.updownup = module.Method([vis], [hid, neg_vis_sample, neg_hid])
 
-        # cd1 updates:
-        self.params = [self.w, self.visb, self.hidb]
-        self.gradients = [
-            T.dot(self.hid, self.input) - T.dot(self.neg_hid, self.vis),
-            self.input - self.vis,
-            self.hid - self.neg_hid ]
+        # function to perform a manual cd update, given the positive and negative visible and hidden values
+        vistemp = T.dmatrix()
+        hidtemp = T.dmatrix()
+        nvistemp = T.dmatrix()
+        nhidtemp = T.dmatrix()
+        self.cd_update = module.Method([vistemp, hidtemp, nvistemp, nhidtemp],
+                [],
+                updates = {self.w: self.w + self.lr * 
+                                   (T.dot(vistemp.T, hidtemp) - 
+                                    T.dot(nvistemp.T, nhidtemp)),
+                           self.visb: self.visb + self.lr *
+                                      (T.sum(vistemp - nvistemp, axis=0)),
+                           self.hidb: self.hidb + self.lr *
+                                      (T.sum(hidtemp - nhidtemp, axis=0))})
 
-    def __instance_initialize(self, obj):
-        obj.w = N.random.standard_normal((self.nhid,self.nvis))
-        obj.genb = N.zeros(self.nvis)
+    # TODO: add parameter for weight initialization
+    def _instance_initialize(self, obj):
+        obj.w = N.random.standard_normal((self.nvis,self.nhid))
+        obj.visb = N.zeros(self.nvis)
         obj.hidb = N.zeros(self.nhid)
+        obj.seed(self.seed)
 
-def RBM_cd():
-    pass;
+    def _instance_cd1(self, obj, input, k=1):
+        poshid, negvissample, neghid = obj.updownup(input)
+        for i in xrange(k-1):
+            ahid, negvissample, neghid = obj.updownup(negvissample)
+        # CD-k update
+        obj.cd_update(input, poshid, negvissample, neghid)
+
 
 def train_rbm(state, channel=lambda *args, **kwargs:None):
     dataset = make_dataset(**state.subdict(prefix='dataset_'))
@@ -56,10 +80,11 @@
 
     rbm_module = RBM(
             nvis=train.x.shape[1],
-            nhid=state.size_hidden)
+            nhid=state.nhid)
+    rbm = rbm_module.make()
 
-    batchsize = state.batchsize
-    verbose = state.verbose
+    batchsize = getattr(state, 'batchsize', 1)
+    verbose = getattr(state, 'verbose', 1)
     iter = [0]
 
     while iter[0] != state.max_iters:
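
For reference, the cd_update expressions in this changeset implement the standard CD-k rule: W += lr * (vis+.T hid+ - vis-.T hid-), with the analogous sums over the batch for the visible and hidden biases. The following is a minimal, self-contained NumPy sketch of the same positive/negative phase and parameter update, for illustration only; it is not part of pylearn, and the names cd_k and sample_bernoulli are made up for the example.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sample_bernoulli(p, rng):
    # draw 0/1 samples with probability p, like random.binomial(shape, 1, p)
    return (rng.uniform(size=p.shape) < p).astype(p.dtype)

def cd_k(vis, w, visb, hidb, lr=0.1, k=1, rng=np.random):
    """One CD-k parameter update for a batch of visible rows `vis`.

    Mirrors updownup + cd_update above:
      w    += lr * (vis+.T hid+  -  vis-.T hid-)
      visb += lr * sum(vis+ - vis-, axis=0)
      hidb += lr * sum(hid+ - hid-, axis=0)
    """
    pos_hid = sigmoid(np.dot(vis, w) + hidb)          # positive phase
    neg_vis_sample = vis
    hid = pos_hid
    for _ in range(k):                                # k steps of the chain
        hid_sample = sample_bernoulli(hid, rng)
        neg_vis = sigmoid(np.dot(hid_sample, w.T) + visb)
        neg_vis_sample = sample_bernoulli(neg_vis, rng)
        hid = sigmoid(np.dot(neg_vis_sample, w) + hidb)
    neg_hid = hid                                     # negative phase
    w += lr * (np.dot(vis.T, pos_hid) - np.dot(neg_vis_sample.T, neg_hid))
    visb += lr * np.sum(vis - neg_vis_sample, axis=0)
    hidb += lr * np.sum(pos_hid - neg_hid, axis=0)
    return w, visb, hidb

As in _instance_cd1 above, the positive-phase hidden probabilities come from the first up-pass on the data, while the negative phase uses the visible and hidden values after k sampling steps; in the module itself the update is driven through the generated cd1/cd_update methods rather than a standalone function like this sketch.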