Mercurial > ift6266
comparison deep/crbm/mnist_crbm.py @ 340:523e7b87c521
Corrected a few bugs, no new features. Supposedly ready to run on cluster.
author | fsavard |
---|---|
date | Sun, 18 Apr 2010 11:39:24 -0400 |
parents | ffbf0e41bcee |
children | b25ad1670ff7 |
comparison
equal
deleted
inserted
replaced
339:ffbf0e41bcee | 340:523e7b87c521 |
---|---|
1 #!/usr/bin/python | 1 #!/usr/bin/python |
2 | 2 |
3 # do this first | 3 import sys |
4 | |
5 # do this before importing custom modules | |
4 from mnist_config import * | 6 from mnist_config import * |
5 | 7 |
6 import sys | 8 if not (len(sys.argv) > 1 and sys.argv[1] in \ |
9 ('test_jobman_entrypoint', 'run_local')): | |
10 # in those cases don't use isolated code, use dev code | |
11 isolate_code() | |
12 | |
7 import os, os.path | 13 import os, os.path |
8 | 14 |
9 import numpy as N | 15 import numpy as N |
10 | 16 |
11 import theano | 17 import theano |
12 import theano.tensor as T | 18 import theano.tensor as T |
13 | 19 |
14 from crbm import CRBM, ConvolutionParams | 20 from crbm import CRBM, ConvolutionParams |
15 | 21 |
22 import pylearn, pylearn.version | |
16 from pylearn.datasets import MNIST | 23 from pylearn.datasets import MNIST |
17 from pylearn.io.image_tiling import tile_raster_images | 24 from pylearn.io.image_tiling import tile_raster_images |
18 | 25 |
19 import Image | 26 import Image |
20 | 27 |
21 from pylearn.io.seriestables import * | 28 from pylearn.io.seriestables import * |
22 import tables | 29 import tables |
23 | 30 |
31 import ift6266 | |
32 | |
24 import utils | 33 import utils |
34 | |
35 if not os.path.exists(IMAGE_OUTPUT_DIR): | |
36 os.mkdir(IMAGE_OUTPUT_DIR) | |
37 elif os.path.isfile(IMAGE_OUTPUT_DIR): | |
38 print "IMAGE_OUTPUT_DIR is not a directory!" | |
39 sys.exit(1) | |
25 | 40 |
26 #def filename_from_time(suffix): | 41 #def filename_from_time(suffix): |
27 # import datetime | 42 # import datetime |
28 # return str(datetime.datetime.now()) + suffix + ".png" | 43 # return str(datetime.datetime.now()) + suffix + ".png" |
29 | 44 |
40 class MnistCrbm(object): | 55 class MnistCrbm(object): |
41 def __init__(self, state): | 56 def __init__(self, state): |
42 self.state = state | 57 self.state = state |
43 | 58 |
44 if TEST_CONFIG: | 59 if TEST_CONFIG: |
60 self.mnist = MNIST.first_1k() | |
61 print "Test config, so loaded MNIST first 1000" | |
62 else: | |
45 self.mnist = MNIST.full()#first_10k() | 63 self.mnist = MNIST.full()#first_10k() |
64 print "Loaded MNIST full" | |
46 | 65 |
47 self.cp = ConvolutionParams( \ | 66 self.cp = ConvolutionParams( \ |
48 num_filters=state.num_filters, | 67 num_filters=state.num_filters, |
49 num_input_planes=1, | 68 num_input_planes=1, |
50 height_filters=state.filter_size, | 69 height_filters=state.filter_size, |