Mercurial > pylearn
comparison pylearn/algorithms/rbm.py @ 544:de6de7c2c54b
merged and changed state to dictionary
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 17 Nov 2008 20:05:31 -0500 |
parents | ee5324c21e60 5b4ccbf022c8 |
children | 40cae12a9bb8 |
comparison
equal
deleted
inserted
replaced
543:34aba0efa3e9 | 544:de6de7c2c54b |
---|---|
9 import numpy as N | 9 import numpy as N |
10 | 10 |
11 from ..datasets import make_dataset | 11 from ..datasets import make_dataset |
12 from .minimizer import make_minimizer | 12 from .minimizer import make_minimizer |
13 from .stopper import make_stopper | 13 from .stopper import make_stopper |
14 | |
15 from ..dbdict.experiment import subdict | 14 from ..dbdict.experiment import subdict |
16 | 15 |
17 class RBM(module.FancyModule): | 16 class RBM(T.RModule): |
18 | 17 |
19 # is it really necessary to pass ALL of these ? - GD | 18 # is it really necessary to pass ALL of these ? - GD |
20 def __init__(self, | 19 def __init__(self, |
21 nvis=None, nhid=None, | 20 nvis=None, nhid=None, |
22 input=None, | 21 input=None, |
23 w=None, hidb=None, visb=None): | 22 w=None, hidb=None, visb=None, |
23 seed=0, lr=0.1): | |
24 | |
24 super(RBM, self).__init__() | 25 super(RBM, self).__init__() |
26 self.nhid, self.nvis = nhid, nvis | |
27 self.lr = lr | |
25 | 28 |
26 # symbolic theano stuff | 29 # symbolic theano stuff |
27 # what about multidimensional inputs/outputs ? do they have to be | 30 # what about multidimensional inputs/outputs ? do they have to be |
28 # flattened or should we used tensors instead ? | 31 # flattened or should we used tensors instead ? |
29 self.w = w if w is not None else module.Member(T.dmatrix()) | 32 self.w = w if w is not None else module.Member(T.dmatrix()) |
30 self.visb = visb if visb is not None else module.Member(T.dvector()) | 33 self.visb = visb if visb is not None else module.Member(T.dvector()) |
31 self.hidb = hidb if hidb is not None else module.Member(T.dvector()) | 34 self.hidb = hidb if hidb is not None else module.Member(T.dvector()) |
35 self.seed = seed; | |
36 | |
37 # 1-step Markov chain | |
38 vis = T.dmatrix() | |
39 hid = sigmoid(T.dot(vis, self.w) + self.hidb) | |
40 hid_sample = self.random.binomial(T.shape(hid), 1, hid) | |
41 neg_vis = sigmoid(T.dot(hid_sample, self.w.T) + self.visb) | |
42 neg_vis_sample = self.random.binomial(T.shape(neg_vis), 1, neg_vis) | |
43 neg_hid = sigmoid(T.dot(neg_vis_sample, self.w) + self.hidb) | |
32 | 44 |
33 # 1-step Markov chain | 45 # function which execute 1-step Markov chain (with and without cd updates) |
34 self.hid = sigmoid(T.dot(self.w,self.input) + self.hidb) | 46 self.updownup = module.Method([vis], [hid, neg_vis_sample, neg_hid]) |
35 self.hid_sample = self.hid #TODO: sample! | |
36 self.vis = sigmoid(T.dot(self.w.T, self.hid) + self.visb) | |
37 self.vis_sample = self.vis #TODO: sample! | |
38 self.neg_hid = sigmoid(T.dot(self.w, self.vis) + self.hidb) | |
39 | 47 |
40 # cd1 updates: | 48 # function to perform manual cd update given 2 visible and 2 hidden values |
41 self.params = [self.w, self.visb, self.hidb] | 49 vistemp = T.dmatrix() |
42 self.gradients = [ | 50 hidtemp = T.dmatrix() |
43 T.dot(self.hid, self.input) - T.dot(self.neg_hid, self.vis), | 51 nvistemp = T.dmatrix() |
44 self.input - self.vis, | 52 nhidtemp = T.dmatrix() |
45 self.hid - self.neg_hid ] | 53 self.cd_update = module.Method([vistemp, hidtemp, nvistemp, nhidtemp], |
54 [], | |
55 updates = {self.w: self.w + self.lr * | |
56 (T.dot(vistemp.T, hidtemp) - | |
57 T.dot(nvistemp.T, nhidtemp)), | |
58 self.visb: self.visb + self.lr * | |
59 (T.sum(vistemp - nvistemp,axis=0)), | |
60 self.hidb: self.hidb + self.lr * | |
61 (T.sum(hidtemp - nhidtemp,axis=0))}); | |
46 | 62 |
47 def __instance_initialize(self, obj): | 63 # TODO: add parameter for weigth initialization |
48 obj.w = N.random.standard_normal((self.nhid,self.nvis)) | 64 def _instance_initialize(self, obj): |
49 obj.genb = N.zeros(self.nvis) | 65 obj.w = N.random.standard_normal((self.nvis,self.nhid)) |
66 obj.visb = N.zeros(self.nvis) | |
50 obj.hidb = N.zeros(self.nhid) | 67 obj.hidb = N.zeros(self.nhid) |
68 obj.seed(self.seed); | |
51 | 69 |
52 def RBM_cd(): | 70 def _instance_cd1(self, obj, input, k=1): |
53 pass; | 71 poshid, negvissample, neghid = obj.updownup(input) |
72 for i in xrange(k-1): | |
73 ahid, negvissample, neghid = obj.updownup(negvissample) | |
74 # CD-k update | |
75 obj.cd_update(input, poshid, negvissample, neghid) | |
76 | |
54 | 77 |
55 def train_rbm(state, channel=lambda *args, **kwargs:None): | 78 def train_rbm(state, channel=lambda *args, **kwargs:None): |
56 dataset = make_dataset(**subdict_copy(state, prefix='dataset_')) | 79 dataset = make_dataset(**subdict_copy(state, prefix='dataset_')) |
57 train = dataset.train | 80 train = dataset.train |
58 | 81 |
59 rbm_module = RBM( | 82 rbm_module = RBM( |
60 nvis=train.x.shape[1], | 83 nvis=train.x.shape[1], |
61 nhid=state['size_hidden']) | 84 nhid=state['nhid']) |
85 rbm = rbm_module.make() | |
62 | 86 |
63 batchsize = state['batchsize'] | 87 batchsize = state.get('batchsize', 1) |
64 verbose = state['verbose'] | 88 verbose = state.get('verbose', 1) |
65 iter = [0] | 89 iter = [0] |
66 | 90 |
67 while iter[0] != state['max_iters']: | 91 while iter[0] != state['max_iters']: |
68 for j in xrange(0,len(train.x)-batchsize+1,batchsize): | 92 for j in xrange(0,len(train.x)-batchsize+1,batchsize): |
69 rbm.cd1(train.x[j:j+batchsize]) | 93 rbm.cd1(train.x[j:j+batchsize]) |