diff deep/crbm/mnist_crbm.py @ 339:ffbf0e41bcee

Added code to run the experiment on a cluster and separated configuration from the rest of the machinery. Not tested yet.
author fsavard
date Sat, 17 Apr 2010 20:29:18 -0400
parents 8d116d4a7593
children 523e7b87c521
line diff
--- a/deep/crbm/mnist_crbm.py	Sat Apr 17 12:42:48 2010 -0400
+++ b/deep/crbm/mnist_crbm.py	Sat Apr 17 20:29:18 2010 -0400
@@ -1,5 +1,8 @@
 #!/usr/bin/python
 
+# do this first
+from mnist_config import *
+
 import sys
 import os, os.path
 
@@ -18,68 +21,44 @@
 from pylearn.io.seriestables import *
 import tables
 
-IMAGE_OUTPUT_DIR = 'img/'
-
-REDUCE_EVERY = 100
-
-def filename_from_time(suffix):
-    import datetime
-    return str(datetime.datetime.now()) + suffix + ".png"
+import utils
 
-# Just a shortcut for a common case where we need a few
-# related Error (float) series
-def get_accumulator_series_array( \
-                hdf5_file, group_name, series_names, 
-                reduce_every,
-                index_names=('epoch','minibatch'),
-                stdout_too=True,
-                skip_hdf5_append=False):
-    all_series = []
-
-    hdf5_file.createGroup('/', group_name)
+#def filename_from_time(suffix):
+#    import datetime
+#    return str(datetime.datetime.now()) + suffix + ".png"
 
-    other_targets = []
-    if stdout_too:
-        other_targets = [StdoutAppendTarget()]
+def jobman_entrypoint(state, channel):
+    # record mercurial versions of each package
+    pylearn.version.record_versions(state,[theano,ift6266,pylearn])
+    channel.save()
 
-    for sn in series_names:
-        series_base = \
-            ErrorSeries(error_name=sn,
-                table_name=sn,
-                hdf5_file=hdf5_file,
-                hdf5_group='/'+group_name,
-                index_names=index_names,
-                other_targets=other_targets,
-                skip_hdf5_append=skip_hdf5_append)
+    crbm = MnistCrbm(state)
+    crbm.train()
 
-        all_series.append( \
-            AccumulatorSeriesWrapper( \
-                    base_series=series_base,
-                    reduce_every=reduce_every))
-
-    ret_wrapper = SeriesArrayWrapper(all_series)
-
-    return ret_wrapper
+    return channel.COMPLETE
 
 class MnistCrbm(object):
-    def __init__(self):
-        self.mnist = MNIST.full()#first_10k()
+    def __init__(self, state):
+        self.state = state
+
+        if TEST_CONFIG:
+            self.mnist = MNIST.first_10k()  # small subset for quick test runs
+        else:
+            self.mnist = MNIST.full()
 
         self.cp = ConvolutionParams( \
-                    num_filters=40,
+                    num_filters=state.num_filters,
                     num_input_planes=1,
-                    height_filters=12,
-                    width_filters=12)
+                    height_filters=state.filter_size,
+                    width_filters=state.filter_size)
 
         self.image_size = (28,28)
 
-        self.minibatch_size = 10
+        self.minibatch_size = state.minibatch_size
 
-        self.lr = 0.01
-        self.sparsity_lambda = 1.0
+        self.lr = state.learning_rate
+        self.sparsity_lambda = state.sparsity_lambda
         # about 1/num_filters, so only one filter active at a time
         # 40 * 0.05 = ~2 filters active for any given pixel
-        self.sparsity_p = 0.05
+        self.sparsity_p = state.sparsity_p
 
         self.crbm = CRBM( \
                     minibatch_size=self.minibatch_size,
@@ -89,12 +68,11 @@
                     sparsity_lambda=self.sparsity_lambda,
                     sparsity_p=self.sparsity_p)
         
-        self.num_epochs = 10
+        self.num_epochs = state.num_epochs
 
         self.init_series()
  
     def init_series(self):
-
         series = {}
 
         basedir = os.getcwd()
@@ -103,38 +81,36 @@
 
         cd_series_names = self.crbm.cd_return_desc
         series['cd'] = \
-            get_accumulator_series_array( \
+            utils.get_accumulator_series_array( \
                 h5f, 'cd', cd_series_names,
                 REDUCE_EVERY,
-                stdout_too=True)
+                stdout_too=SERIES_STDOUT_TOO)
 
         sparsity_series_names = self.crbm.sparsity_return_desc
         series['sparsity'] = \
-            get_accumulator_series_array( \
+            utils.get_accumulator_series_array( \
                 h5f, 'sparsity', sparsity_series_names,
                 REDUCE_EVERY,
-                stdout_too=True)
+                stdout_too=SERIES_STDOUT_TOO)
 
         # so first we create the names for each table, based on 
         # position of each param in the array
-        params_stdout = StdoutAppendTarget("\n------\nParams")
+        params_stdout = []
+        if SERIES_STDOUT_TOO:
+            params_stdout = [StdoutAppendTarget()]
         series['params'] = SharedParamsStatisticsWrapper(
                             new_group_name="params",
                             base_group="/",
                             arrays_names=['W','b_h','b_x'],
                             hdf5_file=h5f,
                             index_names=('epoch','minibatch'),
-                            other_targets=[params_stdout])
+                            other_targets=params_stdout)
 
         self.series = series
 
     def train(self):
         num_minibatches = len(self.mnist.train.x) / self.minibatch_size
 
-        print_every = 1000
-        visualize_every = 5000
-        gibbs_steps_from_random = 1000
-
         for epoch in xrange(self.num_epochs):
             for mb_index in xrange(num_minibatches):
                 mb_x = self.mnist.train.x \
@@ -158,13 +134,22 @@
                     self.series['params'].append( \
                         (epoch, mb_index), self.crbm.params)
 
-                if total_idx % visualize_every == 0:
+                if total_idx % VISUALIZE_EVERY == 0:
                     self.visualize_gibbs_result(\
-                        mb_x, gibbs_steps_from_random)
-                    self.visualize_gibbs_result(mb_x, 1)
-                    self.visualize_filters()
+                        mb_x, GIBBS_STEPS_IN_VIZ_CHAIN,
+                        "gibbs_chain_"+str(epoch)+"_"+str(mb_index))
+                    self.visualize_gibbs_result(mb_x, 1,
+                        "gibbs_1_"+str(epoch)+"_"+str(mb_index))
+                    self.visualize_filters(
+                        "filters_"+str(epoch)+"_"+str(mb_index))
+            if TEST_CONFIG:
+                # do a single epoch for cluster tests config
+                break
+
+        if SAVE_PARAMS:
+            utils.save_params(self.crbm.params, "params.pkl")
     
-    def visualize_gibbs_result(self, start_x, gibbs_steps):
+    def visualize_gibbs_result(self, start_x, gibbs_steps, filename):
         # Run minibatch_size chains for gibbs_steps
         x_samples = None
         if not start_x is None:
@@ -176,15 +161,14 @@
         tile = tile_raster_images(x_samples, self.image_size,
                     (1, self.minibatch_size), output_pixel_vals=True)
 
-        filepath = os.path.join(IMAGE_OUTPUT_DIR,
-                    filename_from_time("gibbs"))
+        filepath = os.path.join(IMAGE_OUTPUT_DIR, filename+".png")
         img = Image.fromarray(tile)
         img.save(filepath)
 
         print "Result of running Gibbs", \
                 gibbs_steps, "times outputed to", filepath
 
-    def visualize_filters(self):
+    def visualize_filters(self, filename):
         cp = self.cp
 
         # filter size
@@ -198,18 +182,26 @@
         tile = tile_raster_images(filters_flattened, fsz, 
                                     tile_shape, output_pixel_vals=True)
 
-        filepath = os.path.join(IMAGE_OUTPUT_DIR,
-                        filename_from_time("filters"))
+        filepath = os.path.join(IMAGE_OUTPUT_DIR, filename+".png")
         img = Image.fromarray(tile)
         img.save(filepath)
 
         print "Filters (as images) outputed to", filepath
-        print "b_h is", self.crbm.b_h.value
-
 
 
 
 if __name__ == '__main__':
-    mc = MnistCrbm()
-    mc.train()
+    args = sys.argv[1:]
 
+    if len(args) == 0:
+        print "Bad usage"
+    elif args[0] == 'jobman_insert':
+        utils.jobman_insert_job_vals(JOBDB, EXPERIMENT_PATH, JOB_VALS)
+    elif args[0] == 'test_jobman_entrypoint':
+        chanmock = DD({'COMPLETE':0,'save':(lambda:None)})
+        jobman_entrypoint(DEFAULT_STATE, chanmock)
+    elif args[0] == 'run_default':
+        mc = MnistCrbm(DEFAULT_STATE)
+        mc.train()
+
+
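The script now starts with "from mnist_config import *", but mnist_config.py itself is not part of this changeset. Below is a minimal sketch of what that module presumably defines, inferred only from the names referenced in the diff above; the plain constants mirror the hard-coded values the diff removes, while the jobman entries (JOBDB, EXPERIMENT_PATH, JOB_VALS and the sweep lists) are illustrative placeholders, not the values in the real file.

    # mnist_config.py -- hypothetical sketch, not the committed file
    from jobman import DD

    TEST_CONFIG = False            # True = quick runs (small dataset, single epoch)

    IMAGE_OUTPUT_DIR = 'img/'      # where Gibbs-chain and filter images are saved
    REDUCE_EVERY = 100             # minibatches accumulated per series point
    VISUALIZE_EVERY = 5000         # minibatches between image dumps
    GIBBS_STEPS_IN_VIZ_CHAIN = 1000
    SERIES_STDOUT_TOO = False      # also echo series to stdout
    SAVE_PARAMS = True             # pickle crbm.params after training

    # jobman bookkeeping (placeholders)
    JOBDB = 'postgres://user@host/dbname/tablename'
    EXPERIMENT_PATH = 'mnist_crbm.jobman_entrypoint'

    # hyperparameter grid inserted by 'jobman_insert' (lists are assumptions)
    JOB_VALS = {'learning_rate': [0.1, 0.01, 0.001],
                'sparsity_lambda': [0.1, 1.0],
                'sparsity_p': [0.05, 0.01],
                'num_filters': [40],
                'filter_size': [12],
                'minibatch_size': [10],
                'num_epochs': [10]}

    # state used by 'run_default' and 'test_jobman_entrypoint'
    DEFAULT_STATE = DD({'learning_rate': 0.01,
                        'sparsity_lambda': 1.0,
                        'sparsity_p': 0.05,
                        'num_filters': 40,
                        'filter_size': 12,
                        'minibatch_size': 10,
                        'num_epochs': 10})

With a config like this in place, the new __main__ dispatch would be driven as, for example, "python mnist_crbm.py test_jobman_entrypoint" for a local smoke test, "python mnist_crbm.py jobman_insert" to populate the job database, or "python mnist_crbm.py run_default" to train with DEFAULT_STATE.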