diff deep/convolutional_dae/run_exp.py @ 291:7d1fa2d7721c

Split out the run_exp method.
author Arnaud Bergeron <abergeron@gmail.com>
date Fri, 26 Mar 2010 18:35:23 -0400
parents
children 8108d271c30c
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/deep/convolutional_dae/run_exp.py	Fri Mar 26 18:35:23 2010 -0400
@@ -0,0 +1,71 @@
+from ift6266.deep.convolutional_dae.scdae import *
+
+class dumb(object):
+    def save(self):
+        pass
+
+def go(state, channel):
+    from ift6266 import datasets
+    from ift6266.deep.convolutional_dae.sgd_opt import sgd_opt
+    import numpy, pylearn, theano, ift6266
+    import pylearn.version
+
+    # params: bsize, pretrain_lr, train_lr, nfilts1, nfilts2, nfilts3, nfilts4
+    #         pretrain_rounds, noise, mlp_sz
+
+    pylearn.version.record_versions(state, [theano, ift6266, pylearn])
+    # TODO: maybe record pynnet version?
+    channel.save()
+
+    dset = datasets.nist_all(1000)
+
+    # Stack convolutional layers for every nonzero filter count, stopping at
+    # the first zero; counts past the first zero are never read, so they do
+    # not need to be set on the state.
+    nfilts = []
+    for i in range(1, 5):
+        nf = getattr(state, 'nfilts%d' % i)
+        if nf == 0:
+            break
+        nfilts.append(nf)
+
+    fsizes = [(5,5)]*len(nfilts)
+    subs = [(2,2)]*len(nfilts)
+    noise = [state.noise]*len(nfilts)
+
+    pretrain_funcs, trainf, evalf, net = build_funcs(
+        img_size=(32, 32),
+        batch_size=state.bsize,
+        filter_sizes=fsizes,
+        num_filters=nfilts,
+        subs=subs,
+        noise=noise,
+        mlp_sizes=[state.mlp_sz],
+        out_size=62,
+        dtype=numpy.float32,
+        pretrain_lr=state.pretrain_lr,
+        train_lr=state.train_lr)
+
+    pretrain_fs, train, valid, test = massage_funcs(
+        state.bsize, dset, pretrain_funcs, trainf, evalf)
+
+    series = create_series()
+
+    do_pretrain(pretrain_fs, state.pretrain_rounds, series['recons_error'])
+
+    sgd_opt(train, valid, test, training_epochs=100000, patience=10000,
+            patience_increase=2., improvement_threshold=0.995,
+            validation_frequency=2500, series=series, net=net)
+
+if __name__ == '__main__':
+    st = dumb()
+    st.bsize = 100
+    st.pretrain_lr = 0.01
+    st.train_lr = 0.1
+    st.nfilts1 = 4
+    st.nfilts2 = 4
+    st.nfilts3 = 0
+    st.pretrain_rounds = 500
+    st.noise = 0.2
+    st.mlp_sz = 500
+    go(st, dumb())
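
The dumb class above is a stand-in for the state and channel objects that a job scheduler (apparently jobman, going by the go(state, channel) signature and the channel.save() calls) would normally supply, so the same entry point can also be driven locally. Below is a minimal sketch, not part of this changeset, of sweeping one of the hyperparameters listed in the params comment; the learning-rate values are illustrative only.

from ift6266.deep.convolutional_dae.run_exp import dumb, go

# Candidate pretraining learning rates to try (illustrative values, not
# taken from the changeset).
for lr in (0.01, 0.001):
    st = dumb()
    st.bsize = 100
    st.pretrain_lr = lr
    st.train_lr = 0.1
    st.nfilts1 = 4
    st.nfilts2 = 4
    st.nfilts3 = 0
    st.pretrain_rounds = 500
    st.noise = 0.2
    st.mlp_sz = 500
    go(st, dumb())  # the second dumb() plays the role of the channel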