Mercurial > pylearn
annotate pylearn/algorithms/rbm.py @ 544:de6de7c2c54b
merged and changed state to dictionary
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 17 Nov 2008 20:05:31 -0500 |
parents | ee5324c21e60 5b4ccbf022c8 |
children | 40cae12a9bb8 |
rev | line source |
---|---|
539
e3f84d260023
first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff
changeset
|
1 import sys, copy |
e3f84d260023
first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff
changeset
|
2 import theano |
e3f84d260023
first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff
changeset
|
3 from theano import tensor as T |
540 | 4 from theano.tensor.nnet import sigmoid |
539
e3f84d260023
first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff
changeset
|
5 from theano.compile import module |
e3f84d260023
first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff
changeset
|
6 from theano import printing, pprint |
e3f84d260023
first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff
changeset
|
7 from theano import compile |
e3f84d260023
first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff
changeset
|
8 |
e3f84d260023
first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff
changeset
|
9 import numpy as N |
e3f84d260023
first attempt at RBMs in pylearn
desjagui@atchoum.iro.umontreal.ca
parents:
diff
changeset
|
10 |
540 | 11 from ..datasets import make_dataset |
12 from .minimizer import make_minimizer | |
13 from .stopper import make_stopper | |
542
ee5324c21e60
changes to dbdict to use dict-like instead of object-like state
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
540
diff
changeset
|
14 from ..dbdict.experiment import subdict |
540 | 15 |
541
5b4ccbf022c8
Working RBM implementation
desjagui@atchoum.iro.umontreal.ca
parents:
540
diff
changeset
|
class RBM(T.RModule):
    """Restricted Boltzmann machine module with a contrastive-divergence update.

    The module owns three parameters (``w``, ``visb``, ``hidb``) and defines
    two ``module.Method`` members:

    * ``updownup(vis)`` -- one up-down-up pass of the Markov chain; returns
      ``[hid, neg_vis_sample, neg_hid]`` where ``hid`` and ``neg_hid`` are
      sigmoid outputs and ``neg_vis_sample`` is a binomial sample.
    * ``cd_update(vis, hid, neg_vis, neg_hid)`` -- returns nothing and updates
      ``w``, ``visb`` and ``hidb`` in place from the difference between the
      first (vis, hid) pair and the second one, scaled by ``lr``.

    Instances additionally get ``cd1(input, k=1)`` (see ``_instance_cd1``).
    """

    # is it really necessary to pass ALL of these ? - GD
    def __init__(self,
                 nvis=None, nhid=None,
                 input=None,
                 w=None, hidb=None, visb=None,
                 seed=0, lr=0.1):
        """Build the symbolic graph of the RBM.

        nvis, nhid -- sizes used by ``_instance_initialize`` to allocate the
                      weight matrix (nvis x nhid) and the two bias vectors.
        input      -- accepted but not used by this constructor.
        w, hidb, visb -- optional pre-built parameters; when None a fresh
                      ``module.Member`` (dmatrix / dvector) is created.
        seed       -- value passed to ``obj.seed`` when an instance is
                      initialized.
        lr         -- multiplier applied to every parameter update.
        """
        super(RBM, self).__init__()
        self.nhid = nhid
        self.nvis = nvis
        self.lr = lr

        # symbolic theano stuff
        # what about multidimensional inputs/outputs ? do they have to be
        # flattened or should we used tensors instead ?
        if w is None:
            w = module.Member(T.dmatrix())
        self.w = w
        if visb is None:
            visb = module.Member(T.dvector())
        self.visb = visb
        if hidb is None:
            hidb = module.Member(T.dvector())
        self.hidb = hidb
        self.seed = seed

        # 1-step Markov chain: visible -> hidden -> visible -> hidden.
        # The sigmoid outputs are used as the probabilities of the binomial
        # samples drawn on the way down and back up.
        chain_start = T.dmatrix()
        hid_prob = sigmoid(T.dot(chain_start, self.w) + self.hidb)
        hid_draw = self.random.binomial(T.shape(hid_prob), 1, hid_prob)
        neg_vis_prob = sigmoid(T.dot(hid_draw, self.w.T) + self.visb)
        neg_vis_draw = self.random.binomial(T.shape(neg_vis_prob), 1,
                                            neg_vis_prob)
        neg_hid_prob = sigmoid(T.dot(neg_vis_draw, self.w) + self.hidb)

        # function which execute 1-step Markov chain (with and without cd updates)
        self.updownup = module.Method([chain_start],
                                      [hid_prob, neg_vis_draw, neg_hid_prob])

        # function to perform manual cd update given 2 visible and 2 hidden values
        pos_vis = T.dmatrix()
        pos_hid = T.dmatrix()
        neg_vis = T.dmatrix()
        neg_hid = T.dmatrix()

        # Positive-phase statistics minus negative-phase statistics.
        w_step = T.dot(pos_vis.T, pos_hid) - T.dot(neg_vis.T, neg_hid)
        visb_step = T.sum(pos_vis - neg_vis, axis=0)
        hidb_step = T.sum(pos_hid - neg_hid, axis=0)

        param_updates = {
            self.w: self.w + self.lr * w_step,
            self.visb: self.visb + self.lr * visb_step,
            self.hidb: self.hidb + self.lr * hidb_step,
        }
        self.cd_update = module.Method([pos_vis, pos_hid, neg_vis, neg_hid],
                                       [],
                                       updates=param_updates)

    # TODO: add parameter for weight initialization
    def _instance_initialize(self, obj):
        """Give a freshly made instance its initial parameter values.

        Weights are drawn from a standard normal of shape (nvis, nhid); both
        bias vectors start at zero.
        NOTE(review): the weights come from numpy's global random state, so
        they do not appear to be controlled by ``self.seed`` -- confirm
        whether that is intended.
        """
        weight_shape = (self.nvis, self.nhid)
        obj.w = N.random.standard_normal(weight_shape)
        obj.visb = N.zeros(self.nvis)
        obj.hidb = N.zeros(self.nhid)
        # presumably seeds the random streams used by `updownup` -- confirm
        obj.seed(self.seed)

    def _instance_cd1(self, obj, input, k=1):
        """Run ``k`` up-down-up passes from ``input`` and apply one update.

        The first pass always runs; ``k - 1`` further passes restart from the
        latest negative visible sample.  The update uses ``input`` with the
        hidden output of the first pass, and the negative values of the last
        pass.
        """
        first_hid, last_neg_vis, last_neg_hid = obj.updownup(input)
        extra_passes = k - 1
        for _step in xrange(extra_passes):
            # Only the negative side of the later passes is kept.
            _hid, last_neg_vis, last_neg_hid = obj.updownup(last_neg_vis)
        # CD-k update
        obj.cd_update(input, first_hid, last_neg_vis, last_neg_hid)
5b4ccbf022c8
Working RBM implementation
desjagui@atchoum.iro.umontreal.ca
parents:
540
diff
changeset
|
76 |
540 | 77 |
def train_rbm(state, channel=lambda *args, **kwargs:None):
    """Train an RBM on the dataset described by ``state``.

    state   -- dict-like experiment state.  Keys read here:
               'dataset_*'  entries forwarded to ``make_dataset``,
               'nhid'       number of hidden units (required),
               'max_iters'  value the minibatch-update counter must reach for
                            training to stop (required),
               'batchsize'  minibatch size (default 1),
               'verbose'    verbosity level (default 1).
    channel -- accepted for interface compatibility; not used here.

    Returns None.
    """
    # Use the `subdict` helper that this module imports; the previous call
    # went through `subdict_copy`, a name that is not defined in this file.
    # presumably `subdict` returns the 'dataset_'-prefixed entries of `state`
    # with the prefix stripped -- confirm against dbdict.experiment.
    dataset = make_dataset(**subdict(state, prefix='dataset_'))
    train = dataset.train

    rbm_module = RBM(
            nvis=train.x.shape[1],
            nhid=state['nhid'])
    rbm = rbm_module.make()

    batchsize = state.get('batchsize', 1)
    verbose = state.get('verbose', 1)
    max_iters = state['max_iters']
    n_updates = 0

    while n_updates != max_iters:
        # Only full minibatches are used; a trailing partial batch is skipped.
        for start in xrange(0, len(train.x) - batchsize + 1, batchsize):
            rbm.cd1(train.x[start:start + batchsize])
            if verbose > 1:
                print('estimated train cost...')
            # NOTE(review): the limit is tested after the update, so when it
            # is reached in the middle of a pass over the data one further
            # update has already been applied -- confirm this is intended.
            if n_updates == max_iters:
                break
            n_updates += 1
100 |