ift6266: diff deep/rbm/mnistrbm.py @ 348:45156cbf6722
training an rbm using cd or pcd
author | goldfinger
---|---
date | Mon, 19 Apr 2010 08:17:45 -0400
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/deep/rbm/mnistrbm.py	Mon Apr 19 08:17:45 2010 -0400
@@ -0,0 +1,215 @@

import sys
import os, os.path
import time

import numpy

import theano
import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

from crbm import CRBM, ConvolutionParams

from pylearn.datasets import MNIST
from pylearn.io.image_tiling import tile_raster_images

import Image

from pylearn.io.seriestables import *
import tables

# NOTE: the RBM class and load_data() used below are not defined in this
# file; they are expected to come from the surrounding code base.

IMAGE_OUTPUT_DIR = 'img/'

REDUCE_EVERY = 100


def filename_from_time(suffix):
    import datetime
    return str(datetime.datetime.now()) + suffix + ".png"


# Just a shortcut for a common case where we need a few
# related Error (float) series
def get_accumulator_series_array(
        hdf5_file, group_name, series_names,
        reduce_every,
        index_names=('epoch', 'minibatch'),
        stdout_too=True,
        skip_hdf5_append=False):
    all_series = []

    hdf5_file.createGroup('/', group_name)

    other_targets = []
    if stdout_too:
        other_targets = [StdoutAppendTarget()]

    for sn in series_names:
        series_base = ErrorSeries(error_name=sn,
                                  table_name=sn,
                                  hdf5_file=hdf5_file,
                                  hdf5_group='/' + group_name,
                                  index_names=index_names,
                                  other_targets=other_targets,
                                  skip_hdf5_append=skip_hdf5_append)

        all_series.append(
            AccumulatorSeriesWrapper(base_series=series_base,
                                     reduce_every=reduce_every))

    ret_wrapper = SeriesArrayWrapper(all_series)

    return ret_wrapper


class ExperienceRbm(object):

    def __init__(self):
        self.mnist = MNIST.full()  # first_10k()

        # `dataset` must name the dataset file to load
        datasets = load_data(dataset)

        train_set_x, train_set_y = datasets[0]
        test_set_x, test_set_y = datasets[2]

        batch_size = 100    # size of the minibatch

        # compute number of minibatches for training, validation and testing
        n_train_batches = train_set_x.value.shape[0] / batch_size

        # allocate symbolic variables for the data
        index = T.lscalar()    # index to a [mini]batch
        x = T.matrix('x')      # the data is presented as rasterized images

        rng = numpy.random.RandomState(123)
        theano_rng = RandomStreams(rng.randint(2 ** 30))

        # construct the RBM class
        self.rbm = RBM(input=x, n_visible=28 * 28,
                       n_hidden=500, numpy_rng=rng, theano_rng=theano_rng)

        # keep what train() and sample_from_rbm() need
        self.train_set_x = train_set_x
        self.test_set_x = test_set_x
        self.batch_size = batch_size
        self.n_train_batches = n_train_batches
        self.index = index
        self.x = x
        self.rng = rng

        self.init_series()

    def init_series(self):

        series = {}

        basedir = os.getcwd()

        h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w")

        cd_series_names = self.rbm.cd_return_desc
        series['cd'] = get_accumulator_series_array(
            h5f, 'cd', cd_series_names,
            REDUCE_EVERY,
            stdout_too=True)

        # create the names for each table, based on the position
        # of each param in the array
        params_stdout = StdoutAppendTarget("\n------\nParams")
        series['params'] = SharedParamsStatisticsWrapper(
            new_group_name="params",
            base_group="/",
            arrays_names=['W', 'b_h', 'b_x'],
            hdf5_file=h5f,
            index_names=('epoch', 'minibatch'),
            other_targets=[params_stdout])

        self.series = series

    def train(self, persistent, learning_rate):

        training_epochs = 15

        # get the cost and the gradient corresponding to one step of CD;
        # with persistent=True the negative chain is kept across updates (PCD)
        if persistent:
            persistent_chain = theano.shared(
                numpy.zeros((self.batch_size, self.rbm.n_hidden)))
            cost, updates = self.rbm.cd(lr=learning_rate,
                                        persistent=persistent_chain)
        else:
            cost, updates = self.rbm.cd(lr=learning_rate)

        dirname = 'lr=%.5f' % learning_rate
        os.makedirs(dirname)
        os.chdir(dirname)

        # the purpose of train_rbm is solely to update the RBM parameters
        train_rbm = theano.function(
            [self.index], cost,
            updates=updates,
            givens={self.x: self.train_set_x[self.index * self.batch_size:
                                             (self.index + 1) * self.batch_size]})

        plotting_time = 0.
        start_time = time.clock()

        # go through training epochs
        for epoch in xrange(training_epochs):

            # go through the training set
            mean_cost = []
            for batch_index in xrange(self.n_train_batches):
                mean_cost += [train_rbm(batch_index)]

        end_time = time.clock()
        pretraining_time = end_time - start_time

    def sample_from_rbm(self, gibbs_steps, test_set_x):

        # find out the number of test samples
        number_of_test_samples = test_set_x.value.shape[0]

        # pick random test examples, with which to initialize the persistent chain
        test_idx = self.rng.randint(number_of_test_samples - 20)
        persistent_vis_chain = theano.shared(
            test_set_x.value[test_idx:test_idx + 20])

        # define one step of Gibbs sampling (mf = mean-field)
        [hid_mf, hid_sample, vis_mf, vis_sample] = \
            self.rbm.gibbs_vhv(persistent_vis_chain)

        # the visible sample is a binomial draw, so it is made of ints (0 and 1)
        # and needs to be cast to the dtype of ``persistent_vis_chain``
        vis_sample = T.cast(vis_sample, dtype=theano.config.floatX)

        # construct the function that implements our persistent chain;
        # we generate the "mean field" activations for plotting and the actual
        # samples for reinitializing the state of our persistent chain
        sample_fn = theano.function([], [vis_mf, vis_sample],
                                    updates={persistent_vis_chain: vis_sample})

        # sample the RBM, plotting every `plot_every`-th sample; do this
        # until `n_samples` images have been plotted
        n_samples = 10
        plot_every = 1000

        for idx in xrange(n_samples):

            # do `plot_every` intermediate samplings that are not plotted
            for jdx in xrange(plot_every):
                vis_mf, vis_sample = sample_fn()

            # construct image
            image = Image.fromarray(tile_raster_images(
                X=vis_mf,
                img_shape=(28, 28),
                tile_shape=(10, 10),
                tile_spacing=(1, 1)))

            image.save('sample_%i_step_%i.png' % (idx, idx * plot_every))


if __name__ == '__main__':
    mc = ExperienceRbm()
    # persistent=True selects PCD; the learning rate here is only an example value
    mc.train(persistent=True, learning_rate=0.1)
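
The changeset above delegates the CD/PCD choice to `self.rbm.cd(lr=..., persistent=...)`, whose implementation lives in the RBM class rather than in this file. As a rough reference for what a single such update computes, here is a minimal NumPy sketch of a CD-1 / PCD-1 parameter step for a Bernoulli-Bernoulli RBM. It is a sketch under assumptions, not the repository's implementation: the name `cd_pcd_step`, the choice to keep the persistent chain state on the visible units (the file above keeps it on the hidden layer), and the default learning rate are all illustrative.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def cd_pcd_step(v_data, W, b_x, b_h, lr=0.1, persistent_v=None, rng=np.random):
    """One CD-1 / PCD-1 update for a Bernoulli-Bernoulli RBM (illustrative sketch).

    v_data : (batch, n_visible) binary minibatch
    W      : (n_visible, n_hidden) weight matrix, updated in place
    b_x    : (n_visible,) visible bias; b_h : (n_hidden,) hidden bias
    If persistent_v is None this is plain CD-1; otherwise the negative
    chain starts from persistent_v (PCD-1) and its new state is returned.
    """
    # positive phase: hidden activations driven by the data
    h_data = sigmoid(v_data.dot(W) + b_h)

    # negative phase: one Gibbs step from the data (CD) or the persistent chain (PCD)
    v_start = v_data if persistent_v is None else persistent_v
    h_start = sigmoid(v_start.dot(W) + b_h)
    h_sample = (rng.uniform(size=h_start.shape) < h_start).astype(W.dtype)
    v_model = sigmoid(h_sample.dot(W.T) + b_x)
    h_model = sigmoid(v_model.dot(W) + b_h)

    # stochastic approximation of the log-likelihood gradient:
    # positive statistics (data) minus negative statistics (model)
    n = v_data.shape[0]
    W += lr * (v_data.T.dot(h_data) - v_model.T.dot(h_model)) / n
    b_h += lr * (h_data - h_model).mean(axis=0)
    b_x += lr * (v_data - v_model).mean(axis=0)

    # for PCD the chain state is carried across minibatches
    if persistent_v is not None:
        return (rng.uniform(size=v_model.shape) < v_model).astype(W.dtype)
    return None
```

Feeding the returned state back in as `persistent_v` on the next minibatch gives PCD; passing `persistent_v=None` on every call restarts the negative chain at the data, which is plain CD-1.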