Mercurial > ift6266
annotate code_tutoriel/dA.py @ 576:185d79636a20
now fits
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Sat, 07 Aug 2010 22:54:54 -0400 |
parents | 4bc5eeec6394 |
children |
All source lines below carry the same annotation: rev 165, changeset 4bc5eeec6394 ("Updating the tutorial code to the latest revisions.") by Dumitru Erhan <dumitru.erhan@gmail.com>.

r"""
This tutorial introduces denoising auto-encoders (dA) using Theano.

Denoising autoencoders are the building blocks for SdA.
They are based on auto-encoders like the ones used in Bengio et al. 2007.
An autoencoder takes an input x and first maps it to a hidden representation
y = f_{\theta}(x) = s(Wx+b), parameterized by \theta={W,b}, where s is the
logistic sigmoid. The resulting latent representation y is then mapped back
to a "reconstructed" vector z \in [0,1]^d in input space,
z = g_{\theta'}(y) = s(W'y + b'). The weight matrix W' can optionally be
constrained such that W' = W^T, in which case the autoencoder is said to
have tied weights. The network is trained to minimize the reconstruction
error (the error between x and z).

For the denoising autoencoder, during training, x is first corrupted into
\tilde{x}, a partially destroyed version of x obtained by means of a
stochastic mapping. Afterwards y is computed as before (using \tilde{x}),
y = s(W\tilde{x} + b), and z as s(W'y + b'). The reconstruction error is
now measured between z and the uncorrupted input x, and is computed as the
cross-entropy:
     - \sum_{k=1}^d [x_k \log z_k + (1-x_k) \log(1-z_k)]


 References:
   - P. Vincent, H. Larochelle, Y. Bengio, P.A. Manzagol: Extracting and
     Composing Robust Features with Denoising Autoencoders, ICML'08,
     1096-1103, 2008
   - Y. Bengio, P. Lamblin, D. Popovici, H. Larochelle: Greedy Layer-Wise
     Training of Deep Networks, Advances in Neural Information Processing
     Systems 19, 2007

"""

import numpy, time, cPickle, gzip

import theano
import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

from logistic_sgd import load_data
from utils import tile_raster_images

import PIL.Image


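The encoder and decoder described in the module docstring can be tried outside Theano. Below is a minimal NumPy sketch, assuming tied weights (W' = W^T) and the logistic sigmoid for s; the helpers `sigmoid`, `encode` and `decode` are hypothetical names, not part of this module. It follows the code's row-vector convention (one example per row, as in `T.dot(input, self.W)`), so s(Wx + b) is written as s(xW + b).

```python
import numpy


def sigmoid(a):
    # logistic function s(a) = 1 / (1 + exp(-a))
    return 1.0 / (1.0 + numpy.exp(-a))


def encode(x, W, b):
    # y = s(x W + b); rows of x are examples, W has shape (n_visible, n_hidden)
    return sigmoid(numpy.dot(x, W) + b)


def decode(y, W, b_prime):
    # z = s(y W^T + b'): tied weights, so W' is the transpose of W
    return sigmoid(numpy.dot(y, W.T) + b_prime)


rng = numpy.random.RandomState(0)
n_visible, n_hidden = 6, 3
W = rng.uniform(-0.5, 0.5, size=(n_visible, n_hidden))
b = numpy.zeros(n_hidden)
b_prime = numpy.zeros(n_visible)

# a minibatch of 4 binary examples
x = rng.binomial(n=1, p=0.5, size=(4, n_visible)).astype(float)
y = encode(x, W, b)
z = decode(y, W, b_prime)
print(y.shape, z.shape)  # (4, 3) (4, 6)
```

Because W has shape (n_visible, n_hidden), reusing its transpose in the decoder is what makes the weights tied: the decoder adds no weight parameters beyond its bias b'.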
class dA(object):
    r"""Denoising Auto-Encoder class (dA)

    A denoising autoencoder tries to reconstruct the input from a corrupted
    version of it by first projecting it into a latent space and then
    projecting it back into the input space. Please refer to Vincent et al.,
    2008 for more details. If x is the input then equation (1) computes a
    partially destroyed version of x by means of a stochastic mapping q_D.
    Equation (2) computes the projection of the input into the latent space.
    Equation (3) computes the reconstruction of the input, while equation (4)
    computes the reconstruction error.

    .. math::

        \tilde{x} ~ q_D(\tilde{x}|x)                                     (1)

        y = s(W \tilde{x} + b)                                           (2)

        z = s(W' y + b')                                                 (3)

        L(x,z) = -\sum_{k=1}^d [x_k \log z_k + (1-x_k) \log(1-z_k)]      (4)

    """

    def __init__(self, numpy_rng, theano_rng=None, input=None, n_visible=784,
                 n_hidden=500, W=None, bhid=None, bvis=None):
        """
        Initialize the dA class by specifying the number of visible units (the
        dimension d of the input) and the number of hidden units (the dimension
        d' of the latent or hidden space). The corruption level is not set
        here; it is passed to ``get_cost_updates``. The constructor also
        receives symbolic variables for the input, weights and biases. Such
        symbolic variables are useful when, for example, the input is the
        result of some computations, or when weights are shared between the
        dA and an MLP layer. When dealing with SdAs this always happens: the
        dA on layer 2 gets as input the output of the dA on layer 1, and the
        weights of the dA are used in the second stage of training to
        construct an MLP.

        :type numpy_rng: numpy.random.RandomState
        :param numpy_rng: numpy random generator used to generate weights

        :type theano_rng: theano.tensor.shared_randomstreams.RandomStreams
        :param theano_rng: Theano random generator; if None is given, one is
                           generated based on a seed drawn from `numpy_rng`

        :type input: theano.tensor.TensorType
        :param input: a symbolic description of the input or None for a
                      standalone dA

        :type n_visible: int
        :param n_visible: number of visible units

        :type n_hidden: int
        :param n_hidden: number of hidden units

        :type W: theano.tensor.TensorType
        :param W: Theano variable pointing to a set of weights that should be
                  shared between the dA and another architecture; if the dA
                  should be standalone, set this to None

        :type bhid: theano.tensor.TensorType
        :param bhid: Theano variable pointing to a set of bias values (for
                     hidden units) that should be shared between the dA and
                     another architecture; if the dA should be standalone,
                     set this to None

        :type bvis: theano.tensor.TensorType
        :param bvis: Theano variable pointing to a set of bias values (for
                     visible units) that should be shared between the dA and
                     another architecture; if the dA should be standalone,
                     set this to None


        """
        self.n_visible = n_visible
        self.n_hidden = n_hidden

        # create a Theano random generator that gives symbolic random values
        if theano_rng is None:
            theano_rng = RandomStreams(numpy_rng.randint(2**30))

        # note : W' was written as `W_prime` and b' as `b_prime`
        if W is None:
            # W is initialized with `initial_W`, which is uniformly sampled
            # from -sqrt(6./(n_visible+n_hidden)) to sqrt(6./(n_hidden+n_visible));
            # the output of uniform is converted using asarray to dtype
            # theano.config.floatX so that the code is runnable on GPU
            initial_W = numpy.asarray(numpy_rng.uniform(
                      low=-numpy.sqrt(6./(n_hidden+n_visible)),
                      high=numpy.sqrt(6./(n_hidden+n_visible)),
                      size=(n_visible, n_hidden)), dtype=theano.config.floatX)
            W = theano.shared(value=initial_W, name='W')

        if bvis is None:
            bvis = theano.shared(value=numpy.zeros(n_visible,
                                         dtype=theano.config.floatX))

        if bhid is None:
            bhid = theano.shared(value=numpy.zeros(n_hidden,
                                         dtype=theano.config.floatX))


        self.W = W
        # b corresponds to the bias of the hidden units
        self.b = bhid
        # b_prime corresponds to the bias of the visible units
        self.b_prime = bvis
        # tied weights, therefore W_prime is W transpose
        self.W_prime = self.W.T
        self.theano_rng = theano_rng
        # if no input is given, generate a variable representing the input
        if input is None:
            # we use a matrix because we expect a minibatch of several
            # examples, each example being a row
            self.x = T.dmatrix(name='input')
        else:
            self.x = input

        self.params = [self.W, self.b, self.b_prime]

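The default weight initialization in `__init__` draws every entry of W uniformly from [-sqrt(6/(n_hidden + n_visible)), sqrt(6/(n_hidden + n_visible))] and starts both bias vectors at zero. The sketch below reproduces that rule with plain NumPy. `init_tied_weights` is a hypothetical name, and the `float32` default stands in for `theano.config.floatX` (an assumption; the real value depends on the Theano configuration).

```python
import numpy


def init_tied_weights(numpy_rng, n_visible, n_hidden, dtype='float32'):
    # bound = sqrt(6 / (n_hidden + n_visible)), the same rule as dA.__init__
    bound = numpy.sqrt(6. / (n_hidden + n_visible))
    W = numpy.asarray(numpy_rng.uniform(low=-bound, high=bound,
                                        size=(n_visible, n_hidden)),
                      dtype=dtype)
    # biases start at zero: one per hidden unit (b) and one per visible unit (b')
    b = numpy.zeros(n_hidden, dtype=dtype)
    b_prime = numpy.zeros(n_visible, dtype=dtype)
    return W, b, b_prime


# the constructor's default sizes: 784 visible units, 500 hidden units
W, b, b_prime = init_tied_weights(numpy.random.RandomState(123), 784, 500)
print(W.shape, W.dtype)  # (784, 500) float32
```

The bound shrinks as the layers get wider, which keeps the initial pre-activations small.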
    def get_corrupted_input(self, input, corruption_level):
        """This function keeps a fraction ``1 - corruption_level`` of the
        input entries unchanged and zeroes out a randomly selected fraction
        ``corruption_level`` of them.

        Note: the first argument of theano_rng.binomial is the shape (size)
        of the random numbers it should produce, the second argument is the
        number of trials, and the third argument is the probability of
        success of each trial.

        This produces an array of 0s and 1s where 1 has a probability of
        ``1 - corruption_level`` and 0 has a probability of
        ``corruption_level``.
        """
        return self.theano_rng.binomial(size=input.shape, n=1,
                                        prob=1 - corruption_level) * input

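`get_corrupted_input` multiplies the input by a Bernoulli mask whose entries are 1 with probability 1 - corruption_level. On average, a fraction corruption_level of the entries is therefore set to zero, and entries that were already zero stay zero. The following is a NumPy sketch of the same masking noise; `corrupt` is a hypothetical name.

```python
import numpy


def corrupt(numpy_rng, x, corruption_level):
    # mask entries are 1 with probability 1 - corruption_level, 0 otherwise
    mask = numpy_rng.binomial(n=1, p=1 - corruption_level, size=x.shape)
    # element-wise product zeroes the masked entries and keeps the rest
    return mask * x


rng = numpy.random.RandomState(42)
x = numpy.ones((1000, 784))
tilde_x = corrupt(rng, x, corruption_level=0.3)
# with an all-ones input, the fraction of zeroed entries should be close to 0.3
print(round(1 - tilde_x.mean(), 2))
```

Because the mask is redrawn on every call, each pass over the data sees a differently corrupted version of the same example.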
    def get_hidden_values(self, input):
        """Computes the values of the hidden layer"""
        return T.nnet.sigmoid(T.dot(input, self.W) + self.b)

    def get_reconstructed_input(self, hidden):
        """Computes the reconstructed input given the values of the hidden
        layer"""
        return T.nnet.sigmoid(T.dot(hidden, self.W_prime) + self.b_prime)

    def get_cost_updates(self, corruption_level, learning_rate):
        """ This function computes the cost and the updates for one training
        step of the dA """

        tilde_x = self.get_corrupted_input(self.x, corruption_level)
        y = self.get_hidden_values(tilde_x)
        z = self.get_reconstructed_input(y)
        # note : we sum over the size of a datapoint; if we are using minibatches,
        #        L will be a vector, with one entry per example in minibatch
        L = - T.sum(self.x * T.log(z) + (1 - self.x) * T.log(1 - z), axis=1)
        # note : L is now a vector, where each element is the cross-entropy cost
        #        of the reconstruction of the corresponding example of the
        #        minibatch. We need to compute the average of all these to get
        #        the cost of the minibatch
        cost = T.mean(L)

        # compute the gradients of the cost of the `dA` with respect
        # to its parameters
        gparams = T.grad(cost, self.params)
        # build the dictionary of updates: one SGD step per parameter
        updates = {}
        for param, gparam in zip(self.params, gparams):
            updates[param] = param - learning_rate * gparam

        return (cost, updates)

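For readers without Theano, the computation performed by `get_cost_updates` can be sketched in plain NumPy. This is a minimal sketch under stated assumptions, not the tutorial's implementation: it assumes tied weights (`W_prime = W.T`), writes the gradients out by hand instead of calling `T.grad`, and uses NumPy's `RandomState.binomial` for the masking noise that `get_corrupted_input` draws. The helper names `corrupt`, `cost_and_grads` and `sgd_step` are hypothetical.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def corrupt(x, corruption_level, rng):
    """Masking noise: keep each entry with probability 1 - corruption_level,
    otherwise set it to 0 (a NumPy analogue of get_corrupted_input)."""
    mask = rng.binomial(n=1, p=1.0 - corruption_level, size=x.shape)
    return mask * x


def cost_and_grads(x, tilde_x, W, b, b_prime):
    """Mean cross-entropy cost of reconstructing the clean x from the
    corrupted tilde_x, plus hand-derived gradients.

    Assumption: tied weights (W_prime = W.T); W has shape (n_visible, n_hidden)."""
    n = x.shape[0]
    y = sigmoid(tilde_x @ W + b)        # hidden code, (n, n_hidden)
    z = sigmoid(y @ W.T + b_prime)      # reconstruction, (n, n_visible)
    L = -np.sum(x * np.log(z) + (1 - x) * np.log(1 - z), axis=1)
    cost = L.mean()                     # average over the minibatch

    # d cost / d (decoder pre-activation): sigmoid + cross-entropy gives z - x.
    d_out = (z - x) / n
    g_b_prime = d_out.sum(axis=0)
    # Back-propagate through W.T and the encoder sigmoid.
    d_hid = (d_out @ W) * y * (1 - y)
    g_b = d_hid.sum(axis=0)
    # W is used by both the encoder and the tied decoder: two gradient terms.
    g_W = tilde_x.T @ d_hid + d_out.T @ y
    return cost, g_W, g_b, g_b_prime


def sgd_step(x, W, b, b_prime, corruption_level, learning_rate, rng):
    """One plain SGD step, mirroring param - learning_rate * gparam.
    Returns the cost before the update and the updated parameters."""
    tilde_x = corrupt(x, corruption_level, rng)
    cost, g_W, g_b, g_b_prime = cost_and_grads(x, tilde_x, W, b, b_prime)
    return (cost, W - learning_rate * g_W, b - learning_rate * g_b,
            b_prime - learning_rate * g_b_prime)
```

Because `cost_and_grads` takes the corrupted input separately from the clean target, a finite-difference check on `W` with a fixed mask is an easy way to confirm the hand-derived gradients.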
def test_dA(learning_rate=0.1, training_epochs=15, dataset='mnist.pkl.gz'):
    """
    This demo is tested on MNIST

    :type learning_rate: float
    :param learning_rate: learning rate used for training the denoising
                          autoencoder

    :type training_epochs: int
    :param training_epochs: number of epochs used for training

    :type dataset: string
    :param dataset: path to the pickled dataset

    """
    datasets = load_data(dataset)
    train_set_x, train_set_y = datasets[0]

    batch_size = 20     # size of the minibatch

    # compute number of minibatches for training
    n_train_batches = train_set_x.value.shape[0] // batch_size

    # allocate symbolic variables for the data
    index = T.lscalar()    # index to a [mini]batch
    x = T.matrix('x')      # the data is presented as rasterized images

    ####################################
    # BUILDING THE MODEL NO CORRUPTION #
    ####################################

    rng = numpy.random.RandomState(123)
    theano_rng = RandomStreams(rng.randint(2 ** 30))

    da = dA(numpy_rng=rng, theano_rng=theano_rng, input=x,
            n_visible=28 * 28, n_hidden=500)

    cost, updates = da.get_cost_updates(corruption_level=0.,
                                        learning_rate=learning_rate)

    train_da = theano.function([index], cost, updates=updates,
            givens={x: train_set_x[index * batch_size:(index + 1) * batch_size]})

    start_time = time.clock()

    ############
    # TRAINING #
    ############

    # go through training epochs
    for epoch in xrange(training_epochs):
        # go through training set
        c = []
        for batch_index in xrange(n_train_batches):
            c.append(train_da(batch_index))

        print 'Training epoch %d, cost ' % epoch, numpy.mean(c)

    end_time = time.clock()

    training_time = (end_time - start_time)

    print ('Training took %f minutes' % (training_time / 60.))

    image = PIL.Image.fromarray(tile_raster_images(X=da.W.value.T,
                 img_shape=(28, 28), tile_shape=(10, 10),
                 tile_spacing=(1, 1)))
    image.save('filters_corruption_0.png')

    #####################################
    # BUILDING THE MODEL CORRUPTION 30% #
    #####################################

    rng = numpy.random.RandomState(123)
    theano_rng = RandomStreams(rng.randint(2 ** 30))

    da = dA(numpy_rng=rng, theano_rng=theano_rng, input=x,
            n_visible=28 * 28, n_hidden=500)

    cost, updates = da.get_cost_updates(corruption_level=0.3,
                                        learning_rate=learning_rate)

    train_da = theano.function([index], cost, updates=updates,
            givens={x: train_set_x[index * batch_size:(index + 1) * batch_size]})

    start_time = time.clock()

    ############
    # TRAINING #
    ############

    # go through training epochs
    for epoch in xrange(training_epochs):
        # go through training set
        c = []
        for batch_index in xrange(n_train_batches):
            c.append(train_da(batch_index))

        print 'Training epoch %d, cost ' % epoch, numpy.mean(c)

    end_time = time.clock()

    training_time = (end_time - start_time)

    print ('Training took %f minutes' % (training_time / 60.))

    image = PIL.Image.fromarray(tile_raster_images(X=da.W.value.T,
                 img_shape=(28, 28), tile_shape=(10, 10),
                 tile_spacing=(1, 1)))
    image.save('filters_corruption_30.png')

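The minibatch loop in `test_dA` relies on two details worth making explicit: `givens` replaces `x` with rows `index*batch_size` up to (but excluding) `(index+1)*batch_size` of the training set, and the floor division that computes `n_train_batches` skips any trailing partial batch. The Python 3 sketch below reproduces that bookkeeping; `train_fn` is a hypothetical stand-in for the compiled `train_da`, and `time.clock()` (removed in Python 3.8) is swapped for `time.perf_counter()`.

```python
import time

import numpy as np


def minibatch(data, index, batch_size):
    """Rows [index*batch_size, (index+1)*batch_size) of data: the same slice
    that ``givens`` substitutes for ``x`` in train_da."""
    return data[index * batch_size:(index + 1) * batch_size]


def run_epochs(train_fn, n_examples, batch_size, training_epochs):
    """Call train_fn(batch_index) on every full minibatch, once per epoch,
    and return the mean minibatch cost of each epoch.

    train_fn is a hypothetical stand-in for the compiled Theano function."""
    # Floor division: a trailing partial minibatch is never visited.
    n_train_batches = n_examples // batch_size
    epoch_costs = []
    start_time = time.perf_counter()
    for epoch in range(training_epochs):
        c = [train_fn(batch_index) for batch_index in range(n_train_batches)]
        epoch_costs.append(float(np.mean(c)))
        print('Training epoch %d, cost %f' % (epoch, epoch_costs[-1]))
    training_time = time.perf_counter() - start_time
    print('Training took %f minutes' % (training_time / 60.))
    return epoch_costs
```

With the 50,000-example training split of `mnist.pkl.gz` and `batch_size = 20`, the division is exact (2,500 minibatches), so nothing is dropped; with other sizes the remainder is skipped in every epoch.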

if __name__ == '__main__':
    test_dA()
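`tile_raster_images` (defined outside this excerpt) turns `da.W.value.T`, one 784-pixel filter per row, into a single grayscale image. The sketch below is a simplified, hypothetical stand-in, not the tutorial's function, that shows the layout arithmetic: with `img_shape=(28, 28)`, `tile_shape=(10, 10)` and `tile_spacing=(1, 1)` the output is `(28 + 1) * 10 - 1 = 289` pixels on a side, and only the first 100 of the 500 filters are drawn. Its per-filter rescaling is an assumption.

```python
import numpy as np


def tile_filters(X, img_shape, tile_shape, tile_spacing=(0, 0)):
    """Lay out the rows of X (one flattened filter per row) on a grid.

    Simplified, hypothetical stand-in for tile_raster_images: each filter is
    rescaled to [0, 255] on its own and spacing pixels are left black."""
    h, w = img_shape
    n_rows, n_cols = tile_shape
    sh, sw = tile_spacing
    out = np.zeros(((h + sh) * n_rows - sh, (w + sw) * n_cols - sw),
                   dtype=np.uint8)
    for k in range(min(n_rows * n_cols, X.shape[0])):
        f = X[k].reshape(img_shape).astype(float)
        # Per-filter rescaling to [0, 1]; the small constant avoids 0/0
        # for a constant filter.
        f = (f - f.min()) / (f.max() - f.min() + 1e-8)
        r, c = divmod(k, n_cols)
        top, left = r * (h + sh), c * (w + sw)
        out[top:top + h, left:left + w] = (f * 255).astype(np.uint8)
    return out
```

The resulting `uint8` array can be passed to `PIL.Image.fromarray` and saved, as `test_dA` does with the output of `tile_raster_images`.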