Mercurial > ift6266
diff deep/crbm/mnist_crbm.py @ 337:8d116d4a7593
Added convolutional RBM (ala Lee09) code, imported from my working dir elsewhere. Seems to work for one layer. No subsampling yet.
author | fsavard |
---|---|
date | Fri, 16 Apr 2010 16:05:55 -0400 |
parents | |
children | ffbf0e41bcee |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/deep/crbm/mnist_crbm.py Fri Apr 16 16:05:55 2010 -0400 @@ -0,0 +1,215 @@ +#!/usr/bin/python + +import sys +import os, os.path + +import numpy as N + +import theano +import theano.tensor as T + +from crbm import CRBM, ConvolutionParams + +from pylearn.datasets import MNIST +from pylearn.io.image_tiling import tile_raster_images + +import Image + +from pylearn.io.seriestables import * +import tables + +IMAGE_OUTPUT_DIR = 'img/' + +REDUCE_EVERY = 100 + +def filename_from_time(suffix): + import datetime + return str(datetime.datetime.now()) + suffix + ".png" + +# Just a shortcut for a common case where we need a few +# related Error (float) series +def get_accumulator_series_array( \ + hdf5_file, group_name, series_names, + reduce_every, + index_names=('epoch','minibatch'), + stdout_too=True, + skip_hdf5_append=False): + all_series = [] + + hdf5_file.createGroup('/', group_name) + + other_targets = [] + if stdout_too: + other_targets = [StdoutAppendTarget()] + + for sn in series_names: + series_base = \ + ErrorSeries(error_name=sn, + table_name=sn, + hdf5_file=hdf5_file, + hdf5_group='/'+group_name, + index_names=index_names, + other_targets=other_targets, + skip_hdf5_append=skip_hdf5_append) + + all_series.append( \ + AccumulatorSeriesWrapper( \ + base_series=series_base, + reduce_every=reduce_every)) + + ret_wrapper = SeriesArrayWrapper(all_series) + + return ret_wrapper + +class MnistCrbm(object): + def __init__(self): + self.mnist = MNIST.full()#first_10k() + + self.cp = ConvolutionParams( \ + num_filters=40, + num_input_planes=1, + height_filters=12, + width_filters=12) + + self.image_size = (28,28) + + self.minibatch_size = 10 + + self.lr = 0.01 + self.sparsity_lambda = 1.0 + # about 1/num_filters, so only one filter active at a time + # 40 * 0.05 = ~2 filters active for any given pixel + self.sparsity_p = 0.05 + + self.crbm = CRBM( \ + minibatch_size=self.minibatch_size, + image_size=self.image_size, + conv_params=self.cp, + learning_rate=self.lr, + sparsity_lambda=self.sparsity_lambda, + sparsity_p=self.sparsity_p) + + self.num_epochs = 10 + + self.init_series() + + def init_series(self): + + series = {} + + basedir = os.getcwd() + + h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w") + + cd_series_names = self.crbm.cd_return_desc + series['cd'] = \ + get_accumulator_series_array( \ + h5f, 'cd', cd_series_names, + REDUCE_EVERY, + stdout_too=True) + + sparsity_series_names = self.crbm.sparsity_return_desc + series['sparsity'] = \ + get_accumulator_series_array( \ + h5f, 'sparsity', sparsity_series_names, + REDUCE_EVERY, + stdout_too=True) + + # so first we create the names for each table, based on + # position of each param in the array + params_stdout = StdoutAppendTarget("\n------\nParams") + series['params'] = SharedParamsStatisticsWrapper( + new_group_name="params", + base_group="/", + arrays_names=['W','b_h','b_x'], + hdf5_file=h5f, + index_names=('epoch','minibatch'), + other_targets=[params_stdout]) + + self.series = series + + def train(self): + num_minibatches = len(self.mnist.train.x) / self.minibatch_size + + print_every = 1000 + visualize_every = 5000 + gibbs_steps_from_random = 1000 + + for epoch in xrange(self.num_epochs): + for mb_index in xrange(num_minibatches): + mb_x = self.mnist.train.x \ + [mb_index : mb_index+self.minibatch_size] + mb_x = mb_x.reshape((self.minibatch_size, 1, 28, 28)) + + #E_h = crbm.E_h_given_x_func(mb_x) + #print "Shape of E_h", E_h.shape + + cd_return = self.crbm.CD_step(mb_x) + sp_return = self.crbm.sparsity_step(mb_x) + + self.series['cd'].append( \ + (epoch, mb_index), cd_return) + self.series['sparsity'].append( \ + (epoch, mb_index), sp_return) + + total_idx = epoch*num_minibatches + mb_index + + if (total_idx+1) % REDUCE_EVERY == 0: + self.series['params'].append( \ + (epoch, mb_index), self.crbm.params) + + if total_idx % visualize_every == 0: + self.visualize_gibbs_result(\ + mb_x, gibbs_steps_from_random) + self.visualize_gibbs_result(mb_x, 1) + self.visualize_filters() + + def visualize_gibbs_result(self, start_x, gibbs_steps): + # Run minibatch_size chains for gibbs_steps + x_samples = None + if not start_x is None: + x_samples = self.crbm.gibbs_samples_from(start_x, gibbs_steps) + else: + x_samples = self.crbm.random_gibbs_samples(gibbs_steps) + x_samples = x_samples.reshape((self.minibatch_size, 28*28)) + + tile = tile_raster_images(x_samples, self.image_size, + (1, self.minibatch_size), output_pixel_vals=True) + + filepath = os.path.join(IMAGE_OUTPUT_DIR, + filename_from_time("gibbs")) + img = Image.fromarray(tile) + img.save(filepath) + + print "Result of running Gibbs", \ + gibbs_steps, "times outputed to", filepath + + def visualize_filters(self): + cp = self.cp + + # filter size + fsz = (cp.height_filters, cp.width_filters) + tile_shape = (cp.num_filters, cp.num_input_planes) + + filters_flattened = self.crbm.W.value.reshape( + (tile_shape[0]*tile_shape[1], + fsz[0]*fsz[1])) + + tile = tile_raster_images(filters_flattened, fsz, + tile_shape, output_pixel_vals=True) + + filepath = os.path.join(IMAGE_OUTPUT_DIR, + filename_from_time("filters")) + img = Image.fromarray(tile) + img.save(filepath) + + print "Filters (as images) outputed to", filepath + print "b_h is", self.crbm.b_h.value + + + + +if __name__ == '__main__': + mc = MnistCrbm() + mc.train() +