Mercurial > ift6266
diff deep/crbm/mnist_crbm.py @ 339:ffbf0e41bcee
Added code to run the experiment on the cluster and separate configuration from the other machinery. Not tested yet.
author | fsavard |
---|---|
date | Sat, 17 Apr 2010 20:29:18 -0400 |
parents | 8d116d4a7593 |
children | 523e7b87c521 |
line wrap: on
line diff
--- a/deep/crbm/mnist_crbm.py Sat Apr 17 12:42:48 2010 -0400 +++ b/deep/crbm/mnist_crbm.py Sat Apr 17 20:29:18 2010 -0400 @@ -1,5 +1,8 @@ #!/usr/bin/python +# do this first +from mnist_config import * + import sys import os, os.path @@ -18,68 +21,44 @@ from pylearn.io.seriestables import * import tables -IMAGE_OUTPUT_DIR = 'img/' - -REDUCE_EVERY = 100 - -def filename_from_time(suffix): - import datetime - return str(datetime.datetime.now()) + suffix + ".png" +import utils -# Just a shortcut for a common case where we need a few -# related Error (float) series -def get_accumulator_series_array( \ - hdf5_file, group_name, series_names, - reduce_every, - index_names=('epoch','minibatch'), - stdout_too=True, - skip_hdf5_append=False): - all_series = [] - - hdf5_file.createGroup('/', group_name) +#def filename_from_time(suffix): +# import datetime +# return str(datetime.datetime.now()) + suffix + ".png" - other_targets = [] - if stdout_too: - other_targets = [StdoutAppendTarget()] +def jobman_entrypoint(state, channel): + # record mercurial versions of each package + pylearn.version.record_versions(state,[theano,ift6266,pylearn]) + channel.save() - for sn in series_names: - series_base = \ - ErrorSeries(error_name=sn, - table_name=sn, - hdf5_file=hdf5_file, - hdf5_group='/'+group_name, - index_names=index_names, - other_targets=other_targets, - skip_hdf5_append=skip_hdf5_append) + crbm = MnistCrbm(state) + crbm.train() - all_series.append( \ - AccumulatorSeriesWrapper( \ - base_series=series_base, - reduce_every=reduce_every)) - - ret_wrapper = SeriesArrayWrapper(all_series) - - return ret_wrapper + return channel.COMPLETE class MnistCrbm(object): - def __init__(self): - self.mnist = MNIST.full()#first_10k() + def __init__(self, state): + self.state = state + + if TEST_CONFIG: + self.mnist = MNIST.full()#first_10k() self.cp = ConvolutionParams( \ - num_filters=40, + num_filters=state.num_filters, num_input_planes=1, - height_filters=12, - width_filters=12) + 
height_filters=state.filter_size, + width_filters=state.filter_size) self.image_size = (28,28) - self.minibatch_size = 10 + self.minibatch_size = state.minibatch_size - self.lr = 0.01 - self.sparsity_lambda = 1.0 + self.lr = state.learning_rate + self.sparsity_lambda = state.sparsity_lambda # about 1/num_filters, so only one filter active at a time # 40 * 0.05 = ~2 filters active for any given pixel - self.sparsity_p = 0.05 + self.sparsity_p = state.sparsity_p self.crbm = CRBM( \ minibatch_size=self.minibatch_size, @@ -89,12 +68,11 @@ sparsity_lambda=self.sparsity_lambda, sparsity_p=self.sparsity_p) - self.num_epochs = 10 + self.num_epochs = state.num_epochs self.init_series() def init_series(self): - series = {} basedir = os.getcwd() @@ -103,38 +81,36 @@ cd_series_names = self.crbm.cd_return_desc series['cd'] = \ - get_accumulator_series_array( \ + utils.get_accumulator_series_array( \ h5f, 'cd', cd_series_names, REDUCE_EVERY, - stdout_too=True) + stdout_too=SERIES_STDOUT_TOO) sparsity_series_names = self.crbm.sparsity_return_desc series['sparsity'] = \ - get_accumulator_series_array( \ + utils.get_accumulator_series_array( \ h5f, 'sparsity', sparsity_series_names, REDUCE_EVERY, - stdout_too=True) + stdout_too=SERIES_STDOUT_TOO) # so first we create the names for each table, based on # position of each param in the array - params_stdout = StdoutAppendTarget("\n------\nParams") + params_stdout = [] + if SERIES_STDOUT_TOO: + params_stdout = [StdoutAppendTarget()] series['params'] = SharedParamsStatisticsWrapper( new_group_name="params", base_group="/", arrays_names=['W','b_h','b_x'], hdf5_file=h5f, index_names=('epoch','minibatch'), - other_targets=[params_stdout]) + other_targets=params_stdout) self.series = series def train(self): num_minibatches = len(self.mnist.train.x) / self.minibatch_size - print_every = 1000 - visualize_every = 5000 - gibbs_steps_from_random = 1000 - for epoch in xrange(self.num_epochs): for mb_index in xrange(num_minibatches): mb_x = 
self.mnist.train.x \ @@ -158,13 +134,22 @@ self.series['params'].append( \ (epoch, mb_index), self.crbm.params) - if total_idx % visualize_every == 0: + if total_idx % VISUALIZE_EVERY == 0: self.visualize_gibbs_result(\ - mb_x, gibbs_steps_from_random) - self.visualize_gibbs_result(mb_x, 1) - self.visualize_filters() + mb_x, GIBBS_STEPS_IN_VIZ_CHAIN, + "gibbs_chain_"+str(epoch)+"_"+str(mb_index)) + self.visualize_gibbs_result(mb_x, 1, + "gibbs_1_"+str(epoch)+"_"+str(mb_index)) + self.visualize_filters( + "filters_"+str(epoch)+"_"+str(mb_index)) + if TEST_CONFIG: + # do a single epoch for cluster tests config + break + + if SAVE_PARAMS: + utils.save_params(self.crbm.params, "params.pkl") - def visualize_gibbs_result(self, start_x, gibbs_steps): + def visualize_gibbs_result(self, start_x, gibbs_steps, filename): # Run minibatch_size chains for gibbs_steps x_samples = None if not start_x is None: @@ -176,15 +161,14 @@ tile = tile_raster_images(x_samples, self.image_size, (1, self.minibatch_size), output_pixel_vals=True) - filepath = os.path.join(IMAGE_OUTPUT_DIR, - filename_from_time("gibbs")) + filepath = os.path.join(IMAGE_OUTPUT_DIR, filename+".png") img = Image.fromarray(tile) img.save(filepath) print "Result of running Gibbs", \ gibbs_steps, "times outputed to", filepath - def visualize_filters(self): + def visualize_filters(self, filename): cp = self.cp # filter size @@ -198,18 +182,26 @@ tile = tile_raster_images(filters_flattened, fsz, tile_shape, output_pixel_vals=True) - filepath = os.path.join(IMAGE_OUTPUT_DIR, - filename_from_time("filters")) + filepath = os.path.join(IMAGE_OUTPUT_DIR, filename+".png") img = Image.fromarray(tile) img.save(filepath) print "Filters (as images) outputed to", filepath - print "b_h is", self.crbm.b_h.value - if __name__ == '__main__': - mc = MnistCrbm() - mc.train() + args = sys.argv[1:] + if len(args) == 0: + print "Bad usage" + elif args[0] == 'jobman_insert': + utils.jobman_insert_job_vals(JOBDB, EXPERIMENT_PATH, JOB_VALS) + 
elif args[0] == 'test_jobman_entrypoint': + chanmock = DD({'COMPLETE':0,'save':(lambda:None)}) + jobman_entrypoint(DEFAULT_STATE, chanmock) + elif args[0] == 'run_default': + mc = MnistCrbm(DEFAULT_STATE) + mc.train() + +