annotate code_tutoriel/rbm.py @ 0:fda5f787baa6

commit initial
author Dumitru Erhan <dumitru.erhan@gmail.com>
date Thu, 21 Jan 2010 11:26:43 -0500
parents
children
rev   line source
0
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
1 import numpy
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
2 import theano
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
3 import theano.tensor as T
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
4
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
5 from theano.compile.sandbox.sharedvalue import shared
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
6 from theano.compile.sandbox.pfunc import pfunc
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
7 from theano.compile.sandbox.shared_randomstreams import RandomStreams
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
8 from theano.tensor.nnet import sigmoid
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
9
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
class A(object):
    """Unfinished placeholder class.

    The original sketch declared a ``propup`` method, decorated with an
    ``execute`` name that is not defined or imported in this module, whose
    body stopped mid-expression at ``self.hid = T.dot(``.  That made the whole
    module a syntax error.  This stub keeps the class and method names
    importable without guessing at the missing computation.
    """

    def propup(self):
        """Placeholder for the symbolic up-propagation that was never written.

        Raises NotImplementedError unconditionally, because the original body
        was incomplete.
        """
        # do symbolic prop
        raise NotImplementedError('A.propup was never implemented')
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
16
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
class RBM(object):
    """Restricted Boltzmann Machine trained with contrastive divergence (CD).

    The constructor builds the symbolic graph and compiles three functions:
    ``pos_phase`` (visible -> hidden), ``neg_phase`` (one hidden -> visible ->
    hidden Gibbs step) and ``cd_step`` (parameter update).  ``cd`` chains them
    into a single CD-k update on one minibatch.
    """

    def __init__(self, input=None, vsize=None, hsize=None, bsize=10, lr=1e-1, seed=123):
        """
        RBM constructor. Defines the parameters of the model along with
        basic operations for inferring hidden from visible (and vice-versa), as well
        as for performing CD updates.
        param input: None for standalone RBMs or symbolic variable if RBM is
                     part of a larger graph.
        param vsize: number of visible units
        param hsize: number of hidden units
        param bsize: size of minibatch (accepted for compatibility; not used
                     by this class, the batch size is read from the data)
        param lr: unsupervised learning rate
        param seed: seed for random number generator
        """
        assert vsize and hsize, 'vsize and hsize must both be given and non-zero'

        self.vsize = vsize
        self.hsize = hsize
        self.lr = shared(lr, 'lr')

        # setup theano random number generator
        self.random = RandomStreams(seed)

        #### INITIALIZATION ####

        # initialize input layer for standalone RBM or layer0 of DBN.
        # Compare against None explicitly: a symbolic variable should not be
        # truth-tested just to find out whether the caller supplied one.
        self.input = input if input is not None else T.dmatrix('input')
        # initialize biases (b: visible, c: hidden)
        self.b = shared(numpy.zeros(vsize), 'b')
        self.c = shared(numpy.zeros(hsize), 'c')
        # initialize random weights; the weight RNG is seeded from a value
        # derived from `seed`, so the initial weights are reproducible
        rngseed = numpy.random.RandomState(seed).randint(2**30)
        rng = numpy.random.RandomState(rngseed)
        ubound = 1. / numpy.sqrt(max(self.vsize, self.hsize))
        self.w = shared(rng.uniform(low=-ubound, high=ubound, size=(hsize, vsize)), 'w')

        #### POSITIVE AND NEGATIVE PHASE ####

        # define graph for positive phase
        ph, ph_s = self.def_propup(self.input)
        # function which computes p(h|v=x) and ~ p(h|v=x)
        self.pos_phase = pfunc([self.input], [ph, ph_s])

        # define graph for negative phase
        nv, nv_s = self.def_propdown(ph_s)
        nh, nh_s = self.def_propup(nv_s)
        # function which computes p(v|h=ph_s), ~ p(v|h=ph_s) and p(h|v=nv_s)
        self.neg_phase = pfunc([ph_s], [nv, nv_s, nh, nh_s])

        # calculate CD gradients for each parameter:
        # (statistics under the data) - (statistics under the reconstruction)
        n_rows = nv.shape[0]
        db = T.mean(self.input, axis=0) - T.mean(nv, axis=0)
        dc = T.mean(ph, axis=0) - T.mean(nh, axis=0)
        dwp = T.dot(ph.T, self.input) / n_rows
        dwn = T.dot(nh.T, nv) / n_rows
        dw = dwp - dwn

        # define dictionary of stochastic gradient update equations.
        # The deltas above are "data minus reconstruction", so the CD rule
        # ADDS lr * delta to each parameter; subtracting would push the model
        # away from the data.
        updates = {self.b: self.b + self.lr * db,
                   self.c: self.c + self.lr * dc,
                   self.w: self.w + self.lr * dw}

        # define private function, which performs one step in direction of CD gradient
        self.cd_step = pfunc([self.input, ph, nv, nh], [], updates=updates)

    def def_propup(self, vis):
        """ Symbolic definition of p(hid|vis)
        param vis: symbolic visible-layer values
        returns: (hid, hid_sample) -- sigmoid activation and a sample drawn from it
        """
        hid_activation = T.dot(vis, self.w.T) + self.c
        hid = sigmoid(hid_activation)
        # multiply by 1.0 to turn the drawn sample into a float expression
        hid_sample = self.random.binomial(T.shape(hid), 1, hid) * 1.0
        return hid, hid_sample

    def def_propdown(self, hid):
        """ Symbolic definition of p(vis|hid)
        param hid: symbolic hidden-layer values
        returns: (vis, vis_sample) -- sigmoid activation and a sample drawn from it
        """
        vis_activation = T.dot(hid, self.w) + self.b
        vis = sigmoid(vis_activation)
        # multiply by 1.0 to turn the drawn sample into a float expression
        vis_sample = self.random.binomial(T.shape(vis), 1, vis) * 1.0
        return vis, vis_sample

    def cd(self, x, k=1):
        """ Performs actual CD update
        param x: minibatch of visible data, passed to pos_phase and cd_step
        param k: number of negative-phase (Gibbs) steps, must be >= 1
        raises ValueError: if k < 1 (no negative statistics would exist)
        """
        if k < 1:
            raise ValueError('k must be >= 1, got %r' % (k,))

        ph, ph_s = self.pos_phase(x)

        # run the chain k times, feeding each hidden sample back in
        nh_s = ph_s
        for _ in range(k):
            nv, nv_s, nh, nh_s = self.neg_phase(nh_s)

        # the last visible SAMPLE is used as the negative visible statistic;
        # nh was computed from that same sample in neg_phase
        self.cd_step(x, ph, nv_s, nh)
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
107
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
108
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
109
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
110 import os
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
111 from pylearn.datasets import MNIST
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
112
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
if __name__ == '__main__':
    # Demo: train an RBM with CD-1 on the first 1k MNIST examples.

    bsize = 10

    # initialize dataset
    dataset = MNIST.first_1k()
    # initialize RBM with 784 visible units and 500 hidden units
    r = RBM(vsize=784, hsize=500, bsize=bsize, lr=0.1)

    # number of complete mini-batches; floor division keeps this an integer
    # under both Python 2 and Python 3 (a trailing partial batch is dropped,
    # as before)
    n_batches = len(dataset.train.x) // bsize

    # for a fixed number of epochs ...
    for e in range(10):

        # parenthesised single argument prints identically on Python 2 and 3
        print('@epoch %i ' % e)

        # iterate over all training set mini-batches
        for i in range(n_batches):

            # next mini-batch: rows [i*bsize, (i+1)*bsize)
            # NOTE(review): assumes dataset.train.x supports slicing like a
            # numpy array -- confirm against pylearn's MNIST loader
            x = dataset.train.x[i * bsize:(i + 1) * bsize]
            r.cd(x)  # perform cd update
fda5f787baa6 commit initial
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
133