Mercurial > ift6266
annotate code_tutoriel/convolutional_mlp.py @ 489:ee9836baade3
merge
author | dumitru@dumitru.mtv.corp.google.com |
---|---|
date | Mon, 31 May 2010 19:07:59 -0700 |
parents | 4bc5eeec6394 |
children |
All lines shown were last changed in changeset 165:4bc5eeec6394 by Dumitru Erhan <dumitru.erhan@gmail.com>: "Updating the tutorial code to the latest revisions."
```python
"""
This tutorial introduces the LeNet5 neural network architecture using Theano. LeNet5 is a
convolutional neural network, good for classifying images. This tutorial shows how to build the
architecture, and comes with all the hyper-parameters you need to reproduce the paper's MNIST
results.


This implementation simplifies the model in the following ways:

 - LeNetConvPoolLayer doesn't implement location-specific gain and bias parameters
 - LeNetConvPoolLayer doesn't implement pooling by average; it implements pooling by max.
 - Digit classification is implemented with a logistic regression rather than an RBF network
 - LeNet5 did not use fully-connected convolutions at the second layer; this implementation does

References:
 - Y. LeCun, L. Bottou, Y. Bengio and P. Haffner: Gradient-Based Learning Applied to Document
   Recognition, Proceedings of the IEEE, 86(11):2278-2324, November 1998.
   http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf
"""
```
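The docstring notes that max pooling replaces LeNet5's average pooling. The following is a minimal NumPy sketch, not part of this file (`pool2d` is a hypothetical helper), showing that the two non-overlapping variants differ only in the reduction applied to each block. Cropping the ragged border mimics `ignore_border=True` in the `max_pool2D` call inside `LeNetConvPoolLayer`.

```python
import numpy as np


def pool2d(feature_map, pool=(2, 2), mode="max"):
    """Non-overlapping pooling of a 2D array; rows/columns that do not fill
    a whole pooling block are dropped (like ignore_border=True)."""
    ph, pw = pool
    h, w = feature_map.shape
    # crop so both dimensions are multiples of the pool size
    cropped = feature_map[: h - h % ph, : w - w % pw]
    # view as (rows // ph, ph, cols // pw, pw) blocks
    blocks = cropped.reshape(h // ph, ph, w // pw, pw)
    if mode == "max":
        return blocks.max(axis=(1, 3))
    if mode == "average":
        return blocks.mean(axis=(1, 3))
    raise ValueError("mode must be 'max' or 'average'")


fmap = np.arange(16, dtype=float).reshape(4, 4)
print(pool2d(fmap, mode="max").tolist())      # [[5.0, 7.0], [13.0, 15.0]]
print(pool2d(fmap, mode="average").tolist())  # [[2.5, 4.5], [10.5, 12.5]]
```

Max pooling keeps only the strongest response in each block, which is what this implementation uses.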
```python
# Written for Python 2 and the 2010-era Theano API
# (cPickle, print statement, downsample.max_pool2D, shared-variable .value).
import numpy, time, cPickle, gzip

import theano
import theano.tensor as T
from theano.tensor.signal import downsample
from theano.tensor.nnet import conv

from logistic_sgd import LogisticRegression, load_data
from mlp import HiddenLayer


class LeNetConvPoolLayer(object):
    """Pool Layer of a convolutional network """

    def __init__(self, rng, input, filter_shape, image_shape, poolsize=(2,2)):
        """
        Allocate a LeNetConvPoolLayer with shared variable internal parameters.

        :type rng: numpy.random.RandomState
        :param rng: a random number generator used to initialize weights

        :type input: theano.tensor.dtensor4
        :param input: symbolic image tensor, of shape image_shape

        :type filter_shape: tuple or list of length 4
        :param filter_shape: (number of filters, num input feature maps,
                              filter height, filter width)

        :type image_shape: tuple or list of length 4
        :param image_shape: (batch size, num input feature maps,
                             image height, image width)

        :type poolsize: tuple or list of length 2
        :param poolsize: the downsampling (pooling) factor (#rows, #cols)
        """

        assert image_shape[1]==filter_shape[1]
        self.input = input

        # initialize weights to temporary values until we know the shape of the output feature
        # maps
        W_values = numpy.zeros(filter_shape, dtype=theano.config.floatX)
        self.W = theano.shared(value = W_values)

        # the bias is a 1D tensor -- one bias per output feature map
        b_values = numpy.zeros((filter_shape[0],), dtype= theano.config.floatX)
        self.b = theano.shared(value= b_values)

        # convolve input feature maps with filters
        conv_out = conv.conv2d(input = input, filters = self.W,
                filter_shape=filter_shape, image_shape=image_shape)

        # there are "num input feature maps * filter height * filter width" inputs
        # to each hidden unit
        fan_in = numpy.prod(filter_shape[1:])
        # each unit in the lower layer receives a gradient from:
        # "num output feature maps * filter height * filter width" / pooling size
        fan_out = filter_shape[0] * numpy.prod(filter_shape[2:]) / numpy.prod(poolsize)
        # replace weight values with random weights
        W_bound = numpy.sqrt(6./(fan_in + fan_out))
        self.W.value = numpy.asarray(
                rng.uniform(low=-W_bound, high=W_bound, size=filter_shape),
                dtype = theano.config.floatX)

        # downsample each feature map individually, using max-pooling
        pooled_out = downsample.max_pool2D( input = conv_out,
                                 ds = poolsize, ignore_border=True)

        # add the bias term. Since the bias is a vector (1D array), we first
        # reshape it to a tensor of shape (1,n_filters,1,1). Each bias will thus
        # be broadcast across mini-batches and feature map width & height
        self.output = T.tanh(pooled_out + self.b.dimshuffle('x', 0, 'x', 'x'))

        # store parameters of this layer
        self.params = [self.W, self.b]
```
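To make the layer's arithmetic concrete, here is a minimal NumPy reference sketch of what `LeNetConvPoolLayer` computes for one forward pass. It is an illustration under stated assumptions, not the Theano implementation. The helpers `init_conv_weights`, `conv2d_valid`, `max_pool` and `conv_pool_forward` are hypothetical. The sketch assumes that `conv.conv2d` performs a true (kernel-flipped) convolution in "valid" mode, as we understand it to. `b[None, :, None, None]` stands in for `self.b.dimshuffle('x', 0, 'x', 'x')`.

```python
import numpy as np


def init_conv_weights(rng, filter_shape, poolsize=(2, 2)):
    """Draw filters uniformly from [-W_bound, W_bound] using the same
    fan-in / fan-out rule as LeNetConvPoolLayer."""
    fan_in = np.prod(filter_shape[1:])
    fan_out = filter_shape[0] * np.prod(filter_shape[2:]) / np.prod(poolsize)
    W_bound = np.sqrt(6.0 / (fan_in + fan_out))
    W = rng.uniform(low=-W_bound, high=W_bound, size=filter_shape)
    return W, float(W_bound)


def conv2d_valid(images, filters):
    """'Valid' 2D convolution with flipped kernels (a true convolution).

    images:  (batch, in_maps, h, w)
    filters: (out_maps, in_maps, fh, fw)
    returns: (batch, out_maps, h - fh + 1, w - fw + 1)
    """
    n, in_maps, h, w = images.shape
    out_maps, f_in_maps, fh, fw = filters.shape
    assert in_maps == f_in_maps  # same check as the assert in __init__
    flipped = filters[:, :, ::-1, ::-1]
    out = np.zeros((n, out_maps, h - fh + 1, w - fw + 1))
    for r in range(h - fh + 1):
        for c in range(w - fw + 1):
            patch = images[:, :, r:r + fh, c:c + fw]
            # sum over input maps and kernel rows/cols -> (batch, out_maps)
            out[:, :, r, c] = np.tensordot(patch, flipped,
                                           axes=([1, 2, 3], [1, 2, 3]))
    return out


def max_pool(x, poolsize=(2, 2)):
    """Non-overlapping max pooling over the last two axes; a ragged border
    is dropped, as with ignore_border=True."""
    ph, pw = poolsize
    n, m, h, w = x.shape
    x = x[:, :, :h - h % ph, :w - w % pw]
    return x.reshape(n, m, h // ph, ph, w // pw, pw).max(axis=(3, 5))


def conv_pool_forward(images, W, b, poolsize=(2, 2)):
    pooled = max_pool(conv2d_valid(images, W), poolsize)
    # b[None, :, None, None] has shape (1, n_filters, 1, 1): the NumPy
    # analogue of self.b.dimshuffle('x', 0, 'x', 'x')
    return np.tanh(pooled + b[None, :, None, None])


rng = np.random.RandomState(23455)
W, W_bound = init_conv_weights(rng, (20, 1, 5, 5))
b = np.zeros(20)
images = rng.uniform(size=(3, 1, 28, 28))
out = conv_pool_forward(images, W, b)
print(out.shape)           # (3, 20, 12, 12)
print(round(W_bound, 4))   # 0.2
```

For the first layer's filter shape, fan_in = 1*5*5 = 25 and fan_out = 20*5*5/4 = 125. This gives a bound of sqrt(6/150) = 0.2.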
```python
def evaluate_lenet5(learning_rate=0.1, n_epochs=200, dataset='mnist.pkl.gz', nkerns=[20,50]):
    """ Demonstrates LeNet on the MNIST dataset

    :type learning_rate: float
    :param learning_rate: learning rate used (factor for the stochastic
                          gradient)

    :type n_epochs: int
    :param n_epochs: maximal number of epochs to run the optimizer

    :type dataset: string
    :param dataset: path to the dataset used for training/testing (MNIST here)

    :type nkerns: list of ints
    :param nkerns: number of kernels on each layer
    """

    rng = numpy.random.RandomState(23455)

    datasets = load_data(dataset)

    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x , test_set_y  = datasets[2]


    batch_size = 500    # size of the minibatch

    # compute number of minibatches for training, validation and testing
    # (Python 2 integer division: a final partial minibatch is dropped)
    n_train_batches = train_set_x.value.shape[0] / batch_size
    n_valid_batches = valid_set_x.value.shape[0] / batch_size
    n_test_batches  = test_set_x.value.shape[0]  / batch_size
```
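In Python 2, `/` between two ints is floor division, so these counts include only full minibatches; the Python 3 equivalent is `//`. The sketch below is a minimal pure-Python illustration of that bookkeeping. `n_batches` and `minibatch_bounds` are hypothetical helpers, and the 1234-example dataset is an arbitrary stand-in. The bounds match the `[index*batch_size:(index+1)*batch_size]` slices that the compiled functions use to pick a minibatch.

```python
def n_batches(n_examples, batch_size):
    """Number of full minibatches; a trailing partial batch is dropped."""
    return n_examples // batch_size


def minibatch_bounds(index, batch_size):
    """[start, stop) rows of minibatch `index`, i.e. the slice
    [index*batch_size:(index+1)*batch_size]."""
    return index * batch_size, (index + 1) * batch_size


rows = list(range(1234))  # hypothetical dataset with 1234 examples
batch_size = 500
for i in range(n_batches(len(rows), batch_size)):
    start, stop = minibatch_bounds(i, batch_size)
    print(i, start, stop, len(rows[start:stop]))
# 0 0 500 500
# 1 500 1000 500
```

Rows 1000 to 1233 are never visited in this example, which is the cost of dropping the partial batch.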
```python
    # allocate symbolic variables for the data
    index = T.lscalar()    # index to a [mini]batch
    x     = T.matrix('x')  # the data is presented as rasterized images
    y     = T.ivector('y') # the labels are presented as 1D vector of
                           # [int] labels


    ishape = (28,28)     # this is the size of MNIST images

    ######################
    # BUILD ACTUAL MODEL #
    ######################
    print '... building the model'

    # Reshape matrix of rasterized images of shape (batch_size,28*28)
    # to a 4D tensor, compatible with our LeNetConvPoolLayer
    layer0_input = x.reshape((batch_size,1,28,28))

    # Construct the first convolutional pooling layer:
    # filtering reduces the image size to (28-5+1,28-5+1)=(24,24)
    # maxpooling reduces this further to (24/2,24/2) = (12,12)
    # 4D output tensor is thus of shape (batch_size,nkerns[0],12,12)
    layer0 = LeNetConvPoolLayer(rng, input=layer0_input,
            image_shape=(batch_size,1,28,28),
            filter_shape=(nkerns[0],1,5,5), poolsize=(2,2))

    # Construct the second convolutional pooling layer
    # filtering reduces the image size to (12-5+1,12-5+1)=(8,8)
    # maxpooling reduces this further to (8/2,8/2) = (4,4)
    # 4D output tensor is thus of shape (batch_size,nkerns[1],4,4)
    layer1 = LeNetConvPoolLayer(rng, input=layer0.output,
            image_shape=(batch_size,nkerns[0],12,12),
            filter_shape=(nkerns[1],nkerns[0],5,5), poolsize=(2,2))

    # Since the HiddenLayer is fully connected, it operates on 2D matrices of
    # shape (batch_size,num_pixels) (i.e. a matrix of rasterized images).
    # flatten(2) thus produces a matrix of shape (batch_size,nkerns[1]*4*4),
    # i.e. (500,800) with the default arguments
    layer2_input = layer1.output.flatten(2)
```
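The shape comments above can be checked with a few lines of arithmetic. The sketch below is a minimal pure-Python illustration using two hypothetical helpers, `conv_pool_size` and `lenet_shapes`. It assumes 5x5 filters, 2x2 pooling and 28x28 inputs, as in the code. Each stage computes (size - filter_size + 1) // pool_size, and the flattened input to the hidden layer has nkerns[1] * 4 * 4 columns.

```python
def conv_pool_size(size, filter_size=5, pool_size=2):
    """Side length after a 'valid' convolution (size - filter_size + 1)
    followed by non-overlapping pooling that ignores the border."""
    return (size - filter_size + 1) // pool_size


def lenet_shapes(batch_size=500, nkerns=(20, 50), image_size=28):
    s0 = conv_pool_size(image_size)  # (28-5+1)//2 = 12
    s1 = conv_pool_size(s0)          # (12-5+1)//2 = 4
    return [
        ("layer0_input", (batch_size, 1, image_size, image_size)),
        ("layer0.output", (batch_size, nkerns[0], s0, s0)),
        ("layer1.output", (batch_size, nkerns[1], s1, s1)),
        ("layer2_input", (batch_size, nkerns[1] * s1 * s1)),
    ]


for name, shape in lenet_shapes():
    print(name, shape)
# layer0_input (500, 1, 28, 28)
# layer0.output (500, 20, 12, 12)
# layer1.output (500, 50, 4, 4)
# layer2_input (500, 800)
```

This is why `layer2` is built with `n_in=nkerns[1]*4*4`.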
```python
    # construct a fully-connected sigmoidal (tanh) layer
    layer2 = HiddenLayer(rng, input=layer2_input, n_in=nkerns[1]*4*4,
                         n_out=500, activation = T.tanh)

    # classify the values of the fully-connected sigmoidal layer
    layer3 = LogisticRegression(input=layer2.output, n_in=500, n_out=10)

    # the cost we minimize during training is the negative log-likelihood (NLL) of the model
    cost = layer3.negative_log_likelihood(y)

    # create a function to compute the mistakes that are made by the model
    test_model = theano.function([index], layer3.errors(y),
             givens = {
                x: test_set_x[index*batch_size:(index+1)*batch_size],
                y: test_set_y[index*batch_size:(index+1)*batch_size]})

    validate_model = theano.function([index], layer3.errors(y),
            givens = {
                x: valid_set_x[index*batch_size:(index+1)*batch_size],
                y: valid_set_y[index*batch_size:(index+1)*batch_size]})
```

    # create a list of all model parameters to be fit by gradient descent
    params = layer3.params + layer2.params + layer1.params + layer0.params

    # create a list of gradients for all model parameters
    grads = T.grad(cost, params)

    # train_model is a function that updates the model parameters by SGD.
    # Since this model has many parameters, it would be tedious to manually
    # create an update rule for each model parameter. We thus create the updates
    # dictionary by automatically looping over all (params[i], grads[i]) pairs.
    updates = {}
    for param_i, grad_i in zip(params, grads):
        updates[param_i] = param_i - learning_rate * grad_i

    train_model = theano.function([index], cost, updates=updates,
            givens={
                x: train_set_x[index*batch_size:(index+1)*batch_size],
                y: train_set_y[index*batch_size:(index+1)*batch_size]})


    ###############
    # TRAIN MODEL #
    ###############
    print '... training'
    # early-stopping parameters
    patience = 10000  # look at this many minibatches regardless
    patience_increase = 2  # when a significant new best is found, extend
                           # patience to iter * patience_increase
    improvement_threshold = 0.995  # a relative improvement of this much is
                                   # considered significant
    validation_frequency = min(n_train_batches, patience/2)
                                  # go through this many
                                  # minibatches before checking the network
                                  # on the validation set; here that means
                                  # once per epoch unless an epoch is longer
                                  # than patience/2 minibatches

    best_params = None
    best_validation_loss = float('inf')
    best_iter = 0
    test_score = 0.
    start_time = time.clock()

    epoch = 0
    done_looping = False

    while (epoch < n_epochs) and (not done_looping):
        epoch = epoch + 1
        for minibatch_index in xrange(n_train_batches):

            # count iterations from 0: epoch is already 1 on the first pass
            iter = (epoch - 1) * n_train_batches + minibatch_index

            if iter % 100 == 0:
                print 'training @ iter = ', iter
            cost_ij = train_model(minibatch_index)

            if (iter + 1) % validation_frequency == 0:

                # compute zero-one loss on validation set
                validation_losses = [validate_model(i) for i in xrange(n_valid_batches)]
                this_validation_loss = numpy.mean(validation_losses)
                print('epoch %i, minibatch %i/%i, validation error %f %%' %
                      (epoch, minibatch_index + 1, n_train_batches,
                       this_validation_loss * 100.))

                # if we got the best validation score until now
                if this_validation_loss < best_validation_loss:

                    # improve patience if loss improvement is good enough
                    if this_validation_loss < best_validation_loss * \
                            improvement_threshold:
                        patience = max(patience, iter * patience_increase)

                    # save best validation score and iteration number
                    best_validation_loss = this_validation_loss
                    best_iter = iter

                    # test it on the test set
                    test_losses = [test_model(i) for i in xrange(n_test_batches)]
                    test_score = numpy.mean(test_losses)
                    print((' epoch %i, minibatch %i/%i, test error of best '
                           'model %f %%') %
                          (epoch, minibatch_index + 1, n_train_batches,
                           test_score * 100.))

            if patience <= iter:
                # set to True so the outer while loop ends too; the break
                # below only leaves the minibatch loop
                done_looping = True
                break

    end_time = time.clock()
    print('Optimization complete.')
    print('Best validation score of %f %% obtained at iteration %i, '
          'with test performance %f %%' %
          (best_validation_loss * 100., best_iter, test_score * 100.))
    print('The code ran for %f minutes' % ((end_time - start_time) / 60.))
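Each `givens` dictionary in `evaluate_lenet5` replaces the symbolic `x` and `y` with rows `index*batch_size` through `(index+1)*batch_size - 1` of a shared dataset. This means a compiled function only needs the integer `index`. The NumPy sketch below mimics that slicing and the averaging of per-minibatch zero-one losses used for `validation_losses` and `test_losses`. The helpers `minibatch` and `zero_one_error` and the toy labels are hypothetical. The sketch also assumes the number of minibatches comes from integer division, so every minibatch has the same size:

```python
import numpy as np


def minibatch(data, index, batch_size):
    """Rows [index * batch_size, (index + 1) * batch_size) of data,
    the same slice the givens dictionaries feed to x and y."""
    return data[index * batch_size:(index + 1) * batch_size]


def zero_one_error(predictions, targets):
    """Fraction of misclassified examples (zero-one loss)."""
    return np.mean(predictions != targets)


# Hypothetical labels and predictions: 12 examples, 4 of them wrong.
batch_size = 4
targets = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
predictions = targets.copy()
predictions[[0, 5, 6, 11]] = 9  # introduce four mistakes

n_batches = len(targets) // batch_size
per_batch = [zero_one_error(minibatch(predictions, i, batch_size),
                            minibatch(targets, i, batch_size))
             for i in range(n_batches)]

# With equal-sized minibatches, the mean of the per-batch error rates
# equals the error rate over the whole set.
overall = zero_one_error(predictions, targets)
```

If the dataset size were not a multiple of `batch_size`, the leftover examples would never be evaluated under this assumption. The mean would then cover only the full minibatches.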

if __name__ == '__main__':
    evaluate_lenet5()
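The loop over `zip(params, grads)` in `evaluate_lenet5` builds one plain stochastic-gradient-descent rule per parameter, `param_i - learning_rate * grad_i`. Theano applies all of these updates each time `train_model(index)` is called. As a rough illustration only, the NumPy sketch below applies the same rule, minibatch by minibatch, to a hypothetical noiseless linear-regression problem that stands in for the LeNet cost. The names `sgd_step` and `mse_grads` and the data are assumptions, and the gradients are written out by hand where the tutorial obtains them from `T.grad`:

```python
import numpy as np


def sgd_step(params, grads, learning_rate):
    """New value param - learning_rate * grad for every (param, grad) pair.

    All new values are computed from the old parameters, as when the whole
    updates dictionary is applied after a call to train_model.
    """
    return [p - learning_rate * g for p, g in zip(params, grads)]


def mse_grads(w, b, x, y):
    """Hand-written gradients of 0.5 * mean((x.w + b - y) ** 2) w.r.t. w and b."""
    residual = np.dot(x, w) + b - y
    return [np.dot(x.T, residual) / len(y), np.mean(residual)]


# Hypothetical noiseless regression data standing in for MNIST and the LeNet cost.
rng = np.random.RandomState(1234)
true_w = np.array([1.5, -2.0, 0.5])
true_b = 0.3
data_x = rng.randn(200, 3)
data_y = np.dot(data_x, true_w) + true_b

params = [np.zeros(3), 0.0]
learning_rate = 0.1
batch_size = 20
n_batches = len(data_y) // batch_size

for epoch in range(50):
    for index in range(n_batches):
        # same slicing as the givens dictionaries
        xb = data_x[index * batch_size:(index + 1) * batch_size]
        yb = data_y[index * batch_size:(index + 1) * batch_size]
        grads = mse_grads(params[0], params[1], xb, yb)
        params = sgd_step(params, grads, learning_rate)

w, b = params
```

Because this toy problem is noiseless and every minibatch shares the same minimizer, `w` and `b` end up very close to `true_w` and `true_b`. The real LeNet cost is non-convex, so no such guarantee carries over.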

def experiment(state, channel):
    evaluate_lenet5(state.learning_rate, dataset=state.dataset)
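The training loop in `evaluate_lenet5` implements patience-based early stopping. Validation runs every `validation_frequency` minibatches, and a large enough relative improvement raises `patience` to `iter * patience_increase`. Training ends once the iteration counter reaches `patience`. Below is a minimal pure-Python sketch of that control flow, using a hypothetical validation curve in place of `validate_model` and no actual training step. It counts iterations from zero and sets `done_looping = True` before breaking. Without that flag, a bare `break` only leaves the inner `for` loop, and the `while` loop would keep starting new epochs. The function name `early_stopping` and the `validation_loss_at` callback are illustrative assumptions:

```python
def early_stopping(validation_loss_at, n_train_batches, n_epochs,
                   patience=10000, patience_increase=2,
                   improvement_threshold=0.995):
    """Return (last epoch, best iteration, best validation loss).

    validation_loss_at(epoch) is a hypothetical stand-in for averaging
    validate_model over the validation minibatches.
    """
    validation_frequency = min(n_train_batches, patience // 2)
    best_validation_loss = float('inf')
    best_iter = 0
    epoch = 0
    done_looping = False

    while epoch < n_epochs and not done_looping:
        epoch += 1
        for minibatch_index in range(n_train_batches):
            # zero-based iteration counter (named `it` to avoid shadowing iter())
            it = (epoch - 1) * n_train_batches + minibatch_index
            # a real loop would call train_model(minibatch_index) here

            if (it + 1) % validation_frequency == 0:
                this_loss = validation_loss_at(epoch)
                if this_loss < best_validation_loss:
                    # only a significant improvement extends patience
                    if this_loss < best_validation_loss * improvement_threshold:
                        patience = max(patience, it * patience_increase)
                    best_validation_loss = this_loss
                    best_iter = it

            if patience <= it:
                done_looping = True  # also ends the outer while loop
                break

    return epoch, best_iter, best_validation_loss


def toy_curve(epoch):
    # Hypothetical validation error: improves for three epochs, then plateaus.
    return [0.5, 0.4, 0.3][min(epoch, 3) - 1]


result = early_stopping(toy_curve, n_train_batches=10, n_epochs=1000, patience=20)
```

With these numbers the loop validates once per epoch. The last significant improvement (epoch 3, iteration 29) raises `patience` to 58, so training stops during epoch 6 and `result` is `(6, 29, 0.3)`, instead of running all 1000 epochs.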