changeset 539:e3f84d260023

first attempt at RBMs in pylearn
author desjagui@atchoum.iro.umontreal.ca
date Thu, 13 Nov 2008 17:01:44 -0500
parents 798607a058bd
children 85d3300c9a9c
files pylearn/algorithms/rbm.py
diffstat 1 files changed, 46 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/algorithms/rbm.py	Thu Nov 13 17:01:44 2008 -0500
@@ -0,0 +1,46 @@
+import sys, copy
+import theano
+from theano import tensor as T
+from theano.tensor import nnet
+from theano.compile import module
+from theano import printing, pprint
+from theano import compile
+
+import numpy as N
+
+class RBM(module.FancyModule):
+
+    # is it really necessary to pass ALL of these ? - GD
+    def __init__(self,
+            nvis=None, nhid=None,
+            input=None,
+            w=None, hidb=None, visb=None):
+       
+        # symbolic theano stuff
+        # what about multidimensional inputs/outputs ? do they have to be 
+        # flattened or should we used tensors instead ?
+        self.w = w if w is not None else module.Member(T.dmatrix())
+        self.visb = visb if visb is not None else module.Member(T.dvector())
+        self.hidb = hidb if hidb is not None else module.Member(T.dvector())
+
+        # 1-step Markov chain
+        self.hid = T.sigmoid(T.dot(self.w,self.input) + self.hidb)
+        self.hid_sample = self.hid #TODO: sample!
+        self.vis = T.sigmoid(T.dot(self.w.T, self.hid) + self.visb)
+        self.vis_sample = self.vis #TODO: sample!
+        self.neg_hid = T.sigmoid(T.dot(self.w, self.vis) + self.hidb)
+
+        # cd1 updates:
+        self.params = [self.w, self.visb, self.hidb]
+        self.gradients = [
+            T.dot(self.hid, self.input) - T.dot(self.neg_hid, self.vis),
+            self.input - self.vis,
+            self.hid - self.neg_hid ]
+
+    def __instance_initialize(self, obj):
+        obj.w = N.random.standard_normal((self.nhid,self.nvis))
+        obj.genb = N.zeros(self.nvis)
+        obj.hidb = N.zeros(self.nhid)
+
+def RBM_cd():
+    pass;