comparison deep/rbm/mnistrbm.py @ 348:45156cbf6722

Training an RBM using CD or PCD
author goldfinger
date Mon, 19 Apr 2010 08:17:45 -0400
parents
children
comparison
equal deleted inserted replaced
347:9685e9d94cc4 348:45156cbf6722
import os
import os.path
import sys
import time

import numpy as N

import tables
import theano
import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

import Image

from crbm import CRBM, ConvolutionParams
from pylearn.datasets import MNIST
from pylearn.io.image_tiling import tile_raster_images
from pylearn.io.seriestables import *
18
19 IMAGE_OUTPUT_DIR = 'img/'
20
21 REDUCE_EVERY = 100
22
def filename_from_time(suffix):
    """Return a PNG filename made of the current timestamp plus *suffix*."""
    import datetime
    now = datetime.datetime.now()
    return "%s%s.png" % (now, suffix)
26
27 # Just a shortcut for a common case where we need a few
28 # related Error (float) series
29
def get_accumulator_series_array(hdf5_file, group_name, series_names,
                                 reduce_every,
                                 index_names=('epoch', 'minibatch'),
                                 stdout_too=True,
                                 skip_hdf5_append=False):
    """Shortcut building several related accumulated ErrorSeries at once.

    Creates HDF5 group ``/group_name`` inside *hdf5_file*, then one
    ErrorSeries per name in *series_names* (each wrapped in an
    AccumulatorSeriesWrapper that reduces every *reduce_every* appends),
    and returns them bundled in a single SeriesArrayWrapper.

    When *stdout_too* is true, each series also echoes to stdout.
    """
    hdf5_file.createGroup('/', group_name)

    extra_targets = [StdoutAppendTarget()] if stdout_too else []

    wrapped = []
    for name in series_names:
        base = ErrorSeries(error_name=name,
                           table_name=name,
                           hdf5_file=hdf5_file,
                           hdf5_group='/' + group_name,
                           index_names=index_names,
                           other_targets=extra_targets,
                           skip_hdf5_append=skip_hdf5_append)
        wrapped.append(AccumulatorSeriesWrapper(base_series=base,
                                                reduce_every=reduce_every))

    return SeriesArrayWrapper(wrapped)
62
class ExperienceRbm(object):
    """Train an RBM on MNIST using CD or PCD, logging series to HDF5."""

    def __init__(self):
        # NOTE(review): `load_data`, `dataset` and `RBM` are not defined in
        # this file -- they appear to come from the deeplearning.net RBM
        # tutorial code; confirm the proper import before running.
        self.mnist = MNIST.full()  # first_10k()

        datasets = load_data(dataset)

        # BUG FIX: these were locals before, but train() and other methods
        # reference them -- store them on the instance.
        self.train_set_x, self.train_set_y = datasets[0]
        self.test_set_x, self.test_set_y = datasets[2]

        self.batch_size = 100  # size of the minibatch

        # number of minibatches for training
        self.n_train_batches = \
            self.train_set_x.value.shape[0] / self.batch_size

        # symbolic variables for the data
        self.index = T.lscalar()  # index to a [mini]batch
        self.x = T.matrix('x')    # data presented as rasterized images

        # BUG FIX: the file imports numpy as N; `numpy.random` was a NameError
        self.rng = N.random.RandomState(123)
        self.theano_rng = RandomStreams(self.rng.randint(2 ** 30))

        # construct the RBM (500 hidden units over 28x28 visible pixels)
        self.rbm = RBM(input=self.x, n_visible=28 * 28,
                       n_hidden=500, numpy_rng=self.rng,
                       theano_rng=self.theano_rng)

        self.init_series()

    def init_series(self):
        """Create the HDF5 file and the series objects used during training."""
        series = {}

        basedir = os.getcwd()
        h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w")

        # one accumulated error series per value returned by one CD step
        cd_series_names = self.rbm.cd_return_desc
        series['cd'] = \
            get_accumulator_series_array(
                h5f, 'cd', cd_series_names,
                REDUCE_EVERY,
                stdout_too=True)

        # per-parameter statistics tables, named after each shared array
        params_stdout = StdoutAppendTarget("\n------\nParams")
        series['params'] = SharedParamsStatisticsWrapper(
            new_group_name="params",
            base_group="/",
            arrays_names=['W', 'b_h', 'b_x'],
            hdf5_file=h5f,
            index_names=('epoch', 'minibatch'),
            other_targets=[params_stdout])

        self.series = series

    def train(self, persistent=True, learning_rate=0.1):
        """Run CD (or PCD when *persistent*) training for a fixed 15 epochs.

        Defaults keep the call from __main__ (which passed no arguments)
        working. Returns the wall-clock training time in seconds.
        """
        training_epochs = 15

        # cost and gradient updates for one step of CD.
        # BUG FIX: the original tested the misspelled name `persistant`.
        if persistent:
            persistent_chain = theano.shared(
                N.zeros((self.batch_size, self.rbm.n_hidden)))
            cost, updates = self.rbm.cd(lr=learning_rate,
                                        persistent=persistent_chain)
        else:
            cost, updates = self.rbm.cd(lr=learning_rate)

        # BUG FIX: self.rbm.learning_rate was never set; use the argument.
        dirname = 'lr=%.5f' % learning_rate
        if not os.path.isdir(dirname):
            os.makedirs(dirname)
        os.chdir(dirname)

        # train_rbm's sole purpose is to update the RBM parameters
        train_rbm = theano.function(
            [self.index], cost,
            updates=updates,
            givens={self.x: self.train_set_x[
                self.index * self.batch_size:
                (self.index + 1) * self.batch_size]})

        start_time = time.clock()

        # go through training epochs
        for epoch in xrange(training_epochs):
            mean_cost = []
            for batch_index in xrange(self.n_train_batches):
                mean_cost.append(train_rbm(batch_index))

        # BUG FIX: end_time was read but never assigned
        end_time = time.clock()
        pretraining_time = end_time - start_time
        return pretraining_time

    def sample_from_rbm(self, gibbs_steps, test_set_x):
        """Sample images from the trained RBM, saving a PNG tile per sample.

        NOTE(review): *gibbs_steps* is accepted but unused (kept for
        interface compatibility); the chain length is the local
        `plot_every` -- confirm which was intended.
        """
        number_of_test_samples = test_set_x.value.shape[0]

        # pick 20 consecutive random test examples to seed the chain.
        # BUG FIX: `rng` was undefined here; use the instance RNG.
        test_idx = self.rng.randint(number_of_test_samples - 20)
        persistent_vis_chain = theano.shared(
            test_set_x.value[test_idx:test_idx + 20])

        # one step of Gibbs sampling (mf = mean-field)
        [hid_mf, hid_sample, vis_mf, vis_sample] = \
            self.rbm.gibbs_vhv(persistent_vis_chain)

        # the binomial draw is made of ints (0 and 1); cast back to the
        # dtype of `persistent_vis_chain` before updating it
        vis_sample = T.cast(vis_sample, dtype=theano.config.floatX)

        # mean-field activations are for plotting; the actual samples
        # reinitialize the state of the persistent chain
        sample_fn = theano.function(
            [], [vis_mf, vis_sample],
            updates={persistent_vis_chain: vis_sample})

        # plot every `plot_every`-th sample until `n_samples` are saved
        n_samples = 10
        plot_every = 1000

        for idx in xrange(n_samples):
            # `plot_every` intermediate samplings we do not keep
            for jdx in xrange(plot_every):
                vis_mf, vis_sample = sample_fn()

            # BUG FIX: the file imports `Image`, not `PIL`
            image = Image.fromarray(tile_raster_images(
                X=vis_mf,
                img_shape=(28, 28),
                tile_shape=(10, 10),
                tile_spacing=(1, 1)))
            # BUG FIX: step count was `idx*jdx` (stale loop variable);
            # idx+1 full chains of plot_every steps have run by now
            image.save('sample_%i_step_%i.png'
                       % (idx, (idx + 1) * plot_every))
210
211
if __name__ == '__main__':
    mc = ExperienceRbm()
    # BUG FIX: train() requires `persistent` and `learning_rate`; the
    # original bare call raised TypeError.
    mc.train(persistent=True, learning_rate=0.1)
215