annotate pylearn/algorithms/rbm.py @ 544:de6de7c2c54b

merged and changed state to dictionary
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 17 Nov 2008 20:05:31 -0500
parents ee5324c21e60 5b4ccbf022c8
children 40cae12a9bb8
rev   line source
539
e3f84d260023 first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff changeset
1 import sys, copy
e3f84d260023 first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff changeset
2 import theano
e3f84d260023 first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff changeset
3 from theano import tensor as T
540
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 539
diff changeset
4 from theano.tensor.nnet import sigmoid
539
e3f84d260023 first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff changeset
5 from theano.compile import module
e3f84d260023 first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff changeset
6 from theano import printing, pprint
e3f84d260023 first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff changeset
7 from theano import compile
e3f84d260023 first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff changeset
8
e3f84d260023 first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff changeset
9 import numpy as N
e3f84d260023 first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff changeset
10
540
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 539
diff changeset
11 from ..datasets import make_dataset
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 539
diff changeset
12 from .minimizer import make_minimizer
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 539
diff changeset
13 from .stopper import make_stopper
542
ee5324c21e60 changes to dbdict to use dict-like instead of object-like state
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 540
diff changeset
14 from ..dbdict.experiment import subdict
540
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 539
diff changeset
15
541
5b4ccbf022c8 Working RBM implementation
desjagui@atchoum.iro.umontreal.ca
parents: 540
diff changeset
class RBM(T.RModule):
    """Restricted Boltzmann Machine written as a Theano module.

    The constructor declares the parameters (`w`, `visb`, `hidb`) and two
    module methods:

    * `updownup(vis)` runs one up/down/up pass and returns
      `[hid, neg_vis_sample, neg_hid]`.
    * `cd_update(vis, hid, neg_vis, neg_hid)` returns nothing and updates
      `w`, `visb` and `hidb` from the supplied positive/negative statistics.

    Sampling goes through `self.random`, which is presumably supplied by the
    `T.RModule` base class -- confirm against the Theano version in use.
    """

    # is it really necessary to pass ALL of these ? - GD
    def __init__(self,
            nvis=None, nhid=None,
            input=None,
            w=None, hidb=None, visb=None,
            seed=0, lr=0.1):
        """Build the symbolic graph of the RBM.

        :param nvis: number of visible units (rows of the weight matrix
            created in `_instance_initialize`).
        :param nhid: number of hidden units (columns of that matrix).
        :param input: accepted but not used anywhere in this constructor.
        :param w: optional pre-existing weight member; a new dmatrix member
            is created when None.
        :param hidb: optional hidden-bias member; new dvector member if None.
        :param visb: optional visible-bias member; new dvector member if None.
        :param seed: value later passed to `obj.seed(...)` when an instance
            is initialized.
        :param lr: learning rate multiplying every update in `cd_update`.
        """
        super(RBM, self).__init__()
        self.nhid = nhid
        self.nvis = nvis
        self.lr = lr

        # symbolic theano stuff
        # what about multidimensional inputs/outputs ? do they have to be
        # flattened or should we used tensors instead ?
        if w is not None:
            self.w = w
        else:
            self.w = module.Member(T.dmatrix())
        if visb is not None:
            self.visb = visb
        else:
            self.visb = module.Member(T.dvector())
        if hidb is not None:
            self.hidb = hidb
        else:
            self.hidb = module.Member(T.dvector())
        self.seed = seed

        # 1-step Markov chain: visible -> hidden -> visible -> hidden
        v0 = T.dmatrix()
        h0_mean = sigmoid(T.dot(v0, self.w) + self.hidb)
        h0_sample = self.random.binomial(T.shape(h0_mean), 1, h0_mean)
        v1_mean = sigmoid(T.dot(h0_sample, self.w.T) + self.visb)
        v1_sample = self.random.binomial(T.shape(v1_mean), 1, v1_mean)
        h1_mean = sigmoid(T.dot(v1_sample, self.w) + self.hidb)

        # function which executes the 1-step Markov chain (no parameter update)
        self.updownup = module.Method([v0], [h0_mean, v1_sample, h1_mean])

        # function to perform a manual cd update given 2 visible and 2 hidden
        # values (positive phase first, negative phase second)
        pos_vis = T.dmatrix()
        pos_hid = T.dmatrix()
        neg_vis = T.dmatrix()
        neg_hid = T.dmatrix()

        w_stat = T.dot(pos_vis.T, pos_hid) - T.dot(neg_vis.T, neg_hid)
        visb_stat = T.sum(pos_vis - neg_vis, axis=0)
        hidb_stat = T.sum(pos_hid - neg_hid, axis=0)

        self.cd_update = module.Method(
                [pos_vis, pos_hid, neg_vis, neg_hid],
                [],
                updates={self.w: self.w + self.lr * (w_stat),
                         self.visb: self.visb + self.lr * (visb_stat),
                         self.hidb: self.hidb + self.lr * (hidb_stat)})

    # TODO: add parameter for weight initialization
    def _instance_initialize(self, obj):
        """Give a freshly made instance its initial parameter values.

        Weights are drawn from a standard normal of shape (nvis, nhid), both
        bias vectors start at zero, and the instance is seeded with
        `self.seed`.
        """
        weight_shape = (self.nvis, self.nhid)
        obj.w = N.random.standard_normal(weight_shape)
        obj.visb = N.zeros(self.nvis)
        obj.hidb = N.zeros(self.nhid)
        obj.seed(self.seed)

    def _instance_cd1(self, obj, input, k=1):
        """Run `k` up/down/up passes from `input`, then apply one CD update.

        The positive statistics always come from the first pass on `input`;
        each additional pass restarts from the previous visible sample and
        only refreshes the negative statistics.
        """
        poshid, negvissample, neghid = obj.updownup(input)
        for _ in xrange(k - 1):
            # hidden activations of the intermediate passes are not needed
            _, negvissample, neghid = obj.updownup(negvissample)
        # CD-k update
        obj.cd_update(input, poshid, negvissample, neghid)
5b4ccbf022c8 Working RBM implementation
desjagui@atchoum.iro.umontreal.ca
parents: 540
diff changeset
76
540
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 539
diff changeset
77
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 539
diff changeset
def train_rbm(state, channel=lambda *args, **kwargs:None):
    """Train an RBM with single-step contrastive divergence.

    :param state: dict-like experiment state. Keys read here:
        every key starting with 'dataset_' (forwarded to `make_dataset`),
        'nhid' (number of hidden units), 'max_iters' (number of minibatch
        updates to perform), and optionally 'batchsize' (default 1) and
        'verbose' (default 1).
    :param channel: accepted for interface compatibility; not used here.
    :raises ValueError: if the training set holds fewer than `batchsize`
        examples, so that no update could ever be performed.

    Exactly `max_iters` updates are performed, cycling over the training set
    as many times as needed. The stop test uses `!=`, so a negative
    `max_iters` never matches and training runs until interrupted.
    """
    # NOTE(review): the visible file imports `subdict` from
    # ..dbdict.experiment; the previously used `subdict_copy` is not defined
    # anywhere in view. Presumably `subdict(d, prefix=...)` returns the
    # prefixed entries with the prefix stripped -- confirm against dbdict.
    dataset = make_dataset(**subdict(state, prefix='dataset_'))
    train = dataset.train

    rbm_module = RBM(
            nvis=train.x.shape[1],
            nhid=state['nhid'])
    rbm = rbm_module.make()

    batchsize = state.get('batchsize', 1)
    verbose = state.get('verbose', 1)
    max_iters = state['max_iters']

    # Only complete minibatches are used; a trailing partial batch is skipped.
    batch_starts = xrange(0, len(train.x) - batchsize + 1, batchsize)
    if max_iters != 0 and len(batch_starts) == 0:
        # Without this guard the outer loop would spin forever doing nothing.
        raise ValueError(
                'training set has %i examples, fewer than batchsize=%i'
                % (len(train.x), batchsize))

    n_updates = 0
    while n_updates != max_iters:
        for j in batch_starts:
            # Test the budget before updating so that the limit is exact
            # even when it is reached in the middle of an epoch.
            if n_updates == max_iters:
                break
            rbm.cd1(train.x[j:j+batchsize])
            n_updates += 1
            if verbose > 1:
                print('estimated train cost...')
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 539
diff changeset
100