ift6266: deep/rbm/mnistrbm.py @ 348:45156cbf6722
training an RBM using CD or PCD
author | goldfinger |
---|---|
date | Mon, 19 Apr 2010 08:17:45 -0400 |
parents | 347:9685e9d94cc4 |
children | |
import sys
import os, os.path
import time

import numpy

import theano
import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

from crbm import CRBM, ConvolutionParams

from pylearn.datasets import MNIST
from pylearn.io.image_tiling import tile_raster_images

import Image

from pylearn.io.seriestables import *
import tables

# RBM and load_data are not defined in this file; they are assumed to be
# the RBM class and the MNIST pickle loader from the accompanying
# tutorial-style modules (rbm.py and logistic_sgd.py in the Deep Learning
# Tutorials). Adjust these imports to wherever those modules live.
from rbm import RBM
from logistic_sgd import load_data

IMAGE_OUTPUT_DIR = 'img/'

REDUCE_EVERY = 100

def filename_from_time(suffix):
    import datetime
    return str(datetime.datetime.now()) + suffix + ".png"

# Just a shortcut for a common case where we need a few
# related Error (float) series

def get_accumulator_series_array( \
        hdf5_file, group_name, series_names,
        reduce_every,
        index_names=('epoch','minibatch'),
        stdout_too=True,
        skip_hdf5_append=False):
    all_series = []

    hdf5_file.createGroup('/', group_name)

    other_targets = []
    if stdout_too:
        other_targets = [StdoutAppendTarget()]

    for sn in series_names:
        series_base = \
            ErrorSeries(error_name=sn,
                table_name=sn,
                hdf5_file=hdf5_file,
                hdf5_group='/'+group_name,
                index_names=index_names,
                other_targets=other_targets,
                skip_hdf5_append=skip_hdf5_append)

        all_series.append( \
            AccumulatorSeriesWrapper( \
                base_series=series_base,
                reduce_every=reduce_every))

    return SeriesArrayWrapper(all_series)

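# For instance (illustrative only, not from the original file; the
# append() convention of taking an index tuple plus the recorded values
# is how pylearn's seriestables series are normally fed):
#
#   h5f = tables.openFile("series.h5", "w")
#   errs = get_accumulator_series_array(h5f, 'recon',
#                                       ['reconstruction_error'],
#                                       reduce_every=REDUCE_EVERY)
#   errs.append((epoch, minibatch), [0.123])
#
# Every `reduce_every` appends, the accumulated values are reduced and
# written out to the HDF5 table, and echoed to stdout when
# stdout_too=True.
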
class ExperienceRbm(object):
    def __init__(self, dataset='mnist.pkl.gz'):
        self.mnist = MNIST.full() #first_10k()

        # load_data returns (train, valid, test) pairs of shared variables
        datasets = load_data(dataset)

        self.train_set_x, self.train_set_y = datasets[0]
        self.test_set_x,  self.test_set_y  = datasets[2]

        self.batch_size = 100    # size of the minibatch

        # compute number of minibatches for training
        self.n_train_batches = self.train_set_x.value.shape[0] / self.batch_size

        # allocate symbolic variables for the data
        self.index = T.lscalar()    # index to a [mini]batch
        self.x = T.matrix('x')      # the data is presented as rasterized images

        self.rng = numpy.random.RandomState(123)
        theano_rng = RandomStreams(self.rng.randint(2 ** 30))

        # construct the RBM; the persistent chain, when one is used, is
        # allocated in train() rather than here
        self.rbm = RBM(input=self.x, n_visible=28 * 28,
                       n_hidden=500, numpy_rng=self.rng, theano_rng=theano_rng)

        self.init_series()

    def init_series(self):

        series = {}

        basedir = os.getcwd()

        h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w")

        # accumulated error series for the values returned by one step of CD
        cd_series_names = self.rbm.cd_return_desc
        series['cd'] = \
            get_accumulator_series_array( \
                h5f, 'cd', cd_series_names,
                REDUCE_EVERY,
                stdout_too=True)

        # the name of each table is based on the position of the
        # corresponding parameter in arrays_names
        params_stdout = StdoutAppendTarget("\n------\nParams")
        series['params'] = SharedParamsStatisticsWrapper(
                            new_group_name="params",
                            base_group="/",
                            arrays_names=['W','b_h','b_x'],
                            hdf5_file=h5f,
                            index_names=('epoch','minibatch'),
                            other_targets=[params_stdout])

        self.series = series

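    # These series are meant to be fed from the training loop, e.g.
    # (illustrative; the append() calls follow the seriestables
    # convention, and rbm.params is assumed to hold the shared W, b_h, b_x):
    #   self.series['cd'].append((epoch, minibatch), cd_values)
    #   self.series['params'].append((epoch, minibatch), self.rbm.params)
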
    def train(self, persistent, learning_rate):

        training_epochs = 15

        # get the cost and the updates corresponding to one step of CD
        if persistent:
            persistent_chain = theano.shared(
                numpy.zeros((self.batch_size, self.rbm.n_hidden)))
            cost, updates = self.rbm.cd(lr=learning_rate,
                                        persistent=persistent_chain)
        else:
            cost, updates = self.rbm.cd(lr=learning_rate)

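        # With persistent=True this is PCD (persistent contrastive
        # divergence, Tieleman 2008): the negative-phase Gibbs chain keeps
        # its state across parameter updates instead of restarting at the
        # data, which tends to explore the model distribution better, at
        # the cost of one stored chain state per minibatch example.
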
        dirname = 'lr=%.5f' % learning_rate
        os.makedirs(dirname)
        os.chdir(dirname)

        # the purpose of train_rbm is solely to update the RBM parameters
        train_rbm = theano.function([self.index], cost,
                updates=updates,
                givens={self.x: self.train_set_x[self.index * self.batch_size:
                                                 (self.index + 1) * self.batch_size]})

        start_time = time.clock()

        # go through training epochs
        for epoch in xrange(training_epochs):

            # go through the training set
            mean_cost = []
            for batch_index in xrange(self.n_train_batches):
                mean_cost += [train_rbm(batch_index)]

            print 'Training epoch %d, mean cost' % epoch, numpy.mean(mean_cost)

        end_time = time.clock()
        pretraining_time = end_time - start_time

    def sample_from_rbm(self, gibbs_steps, test_set_x):

        # find out the number of test samples
        number_of_test_samples = test_set_x.value.shape[0]

        # pick random test examples with which to initialize the persistent chain
        test_idx = self.rng.randint(number_of_test_samples - 20)
        persistent_vis_chain = theano.shared(
                test_set_x.value[test_idx:test_idx + 20])

        # define one step of Gibbs sampling (mf = mean-field)
        [hid_mf, hid_sample, vis_mf, vis_sample] = \
            self.rbm.gibbs_vhv(persistent_vis_chain)

        # the visible sample at the end of the chain is computed as a
        # binomial draw, so it is made of ints (0 and 1) and needs to be
        # cast back to the same dtype as ``persistent_vis_chain``
        vis_sample = T.cast(vis_sample, dtype=theano.config.floatX)

        # construct the function that implements our persistent chain:
        # we generate the "mean field" activations for plotting and the
        # actual samples for reinitializing the state of the chain
        sample_fn = theano.function([], [vis_mf, vis_sample],
                        updates={persistent_vis_chain: vis_sample})

        # sample from the RBM, saving one image of the chain every
        # `plot_every` steps, until `n_samples` images have been saved
        n_samples = 10
        plot_every = gibbs_steps    # number of chain steps between saved images

        for idx in xrange(n_samples):

            # run `plot_every` intermediate sampling steps that we discard
            for jdx in xrange(plot_every):
                vis_mf, vis_sample = sample_fn()

            # construct the image from the mean-field activations
            image = Image.fromarray(tile_raster_images(
                        X=vis_mf,
                        img_shape=(28, 28),
                        tile_shape=(10, 10),
                        tile_spacing=(1, 1)))

            image.save('sample_%i_step_%i.png' % (idx, (idx + 1) * plot_every))

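# After training, samples can be drawn from the model with, e.g.
# (illustrative call; the __main__ below only runs training):
#   mc.sample_from_rbm(gibbs_steps=1000, test_set_x=mc.test_set_x)
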
if __name__ == '__main__':
    mc = ExperienceRbm()
    # train() requires these two arguments; the values here are only
    # sensible defaults, not from the original file
    mc.train(persistent=True, learning_rate=0.1)
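    # To compare plain CD against PCD, one could instead run (values are
    # illustrative; note that each call to train() creates and chdirs
    # into its own 'lr=...' directory):
    #   ExperienceRbm().train(persistent=False, learning_rate=0.1)
    #   ExperienceRbm().train(persistent=True, learning_rate=0.1)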