comparison deep/crbm/mnist_crbm.py @ 340:523e7b87c521

Corrected a few bugs, no new features. Supposedly ready to run on cluster.
author fsavard
date Sun, 18 Apr 2010 11:39:24 -0400
parents ffbf0e41bcee
children b25ad1670ff7
comparison
equal deleted inserted replaced
339:ffbf0e41bcee 340:523e7b87c521
1 #!/usr/bin/python 1 #!/usr/bin/python
2 2
3 # do this first 3 import sys
4
5 # do this before importing custom modules
4 from mnist_config import * 6 from mnist_config import *
5 7
6 import sys 8 if not (len(sys.argv) > 1 and sys.argv[1] in \
9 ('test_jobman_entrypoint', 'run_local')):
10 # in those cases don't use isolated code, use dev code
11 isolate_code()
12
7 import os, os.path 13 import os, os.path
8 14
9 import numpy as N 15 import numpy as N
10 16
11 import theano 17 import theano
12 import theano.tensor as T 18 import theano.tensor as T
13 19
14 from crbm import CRBM, ConvolutionParams 20 from crbm import CRBM, ConvolutionParams
15 21
22 import pylearn, pylearn.version
16 from pylearn.datasets import MNIST 23 from pylearn.datasets import MNIST
17 from pylearn.io.image_tiling import tile_raster_images 24 from pylearn.io.image_tiling import tile_raster_images
18 25
19 import Image 26 import Image
20 27
21 from pylearn.io.seriestables import * 28 from pylearn.io.seriestables import *
22 import tables 29 import tables
23 30
31 import ift6266
32
24 import utils 33 import utils
34
35 if not os.path.exists(IMAGE_OUTPUT_DIR):
36 os.mkdir(IMAGE_OUTPUT_DIR)
37 elif os.path.isfile(IMAGE_OUTPUT_DIR):
38 print "IMAGE_OUTPUT_DIR is not a directory!"
39 sys.exit(1)
25 40
26 #def filename_from_time(suffix): 41 #def filename_from_time(suffix):
27 # import datetime 42 # import datetime
28 # return str(datetime.datetime.now()) + suffix + ".png" 43 # return str(datetime.datetime.now()) + suffix + ".png"
29 44
40 class MnistCrbm(object): 55 class MnistCrbm(object):
41 def __init__(self, state): 56 def __init__(self, state):
42 self.state = state 57 self.state = state
43 58
44 if TEST_CONFIG: 59 if TEST_CONFIG:
60 self.mnist = MNIST.first_1k()
61 print "Test config, so loaded MNIST first 1000"
62 else:
45 self.mnist = MNIST.full()#first_10k() 63 self.mnist = MNIST.full()#first_10k()
64 print "Loaded MNIST full"
46 65
47 self.cp = ConvolutionParams( \ 66 self.cp = ConvolutionParams( \
48 num_filters=state.num_filters, 67 num_filters=state.num_filters,
49 num_input_planes=1, 68 num_input_planes=1,
50 height_filters=state.filter_size, 69 height_filters=state.filter_size,