annotate deep/stacked_dae/v_guillaume/nist_sda.py @ 631:510220effb14

corrections demandees par reviewer
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Sat, 19 Mar 2011 22:44:53 -0400
parents 0ca069550abd
children
rev   line source
436
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
1 #!/usr/bin/python
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
2 # -*- coding: utf-8 -*-
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
3 # coding: utf-8
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
4
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
5 import ift6266
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
6 import pylearn
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
7
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
8 import numpy
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
9 import theano
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
10 import time
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
11
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
12 import pylearn.version
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
13 import theano.tensor as T
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
14 from theano.tensor.shared_randomstreams import RandomStreams
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
15
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
16 import copy
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
17 import sys
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
18 import os
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
19 import os.path
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
20
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
21 from jobman import DD
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
22 import jobman, jobman.sql
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
23 from pylearn.io import filetensor
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
24
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
25 from utils import produit_cartesien_jobs
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
26 from copy import copy
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
27
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
28 from sgd_optimization import SdaSgdOptimizer
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
29
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
30 #from ift6266.utils.scalar_series import *
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
31 from ift6266.utils.seriestables import *
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
32 import tables
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
33
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
34 from ift6266 import datasets
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
35 from config import *
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
36
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
def jobman_entrypoint(state, channel):
    """Run one SDA experiment; this is the function called by jobman upon
    launching each job.

    Its path is the one given when inserting jobs: see EXPERIMENT_PATH.

    The job pretrains the SDA on the selected NIST subdataset, then
    finetunes it according to ``finetune_set``:

    * 0  -- finetune on the subdataset (validated on the same subdataset)
    * 1  -- finetune on P07 (validated on NIST)
    * 2  -- finetune on P07, then on NIST
    * 3  -- finetune on NIST only the logistic regression on top
            (validated on P07)
    * -1 -- run the four finetunings above in sequence (NIST, P07,
            P07 followed by NIST, logistic-regression-only)

    Parameters
    ----------
    state : jobman DD
        Job hyperparameters.  Required: ``pretrain_choice`` (0 -> series
        written to ``series_NIST.h5``, 1 -> ``series_P07.h5``) and
        ``num_hidden_layers``.  Optional keys and their fall-backs:
        ``reduce_train_to`` (REDUCE_TRAIN_TO if truthy, else None),
        ``decrease_lr`` (0), ``decrease_lr_pretrain`` (0),
        ``subdataset`` (SUBDATASET_NIST), ``finetune_set`` (FINETUNE_SET),
        ``max_finetuning_epochs`` (MAX_FINETUNING_EPOCHS),
        ``max_finetuning_epochs_P07`` (same as max_finetuning_epochs).
        The whole ``state`` is also handed to the optimizer as its
        hyperparameters.
    channel : jobman channel
        Used to persist ``state`` (``channel.save()``) between phases.

    Returns
    -------
    channel.COMPLETE

    Raises
    ------
    ValueError
        If ``state['pretrain_choice']`` is neither 0 nor 1, since no series
        file name would be defined for it.
    """
    # Record mercurial versions of each package.
    pylearn.version.record_versions(state, [theano, ift6266, pylearn])
    # TODO: remove this, bad for number of simultaneous requests on DB
    channel.save()

    # For test runs, we don't want to use the whole dataset, so reduce it
    # if asked to (forwarded to the optimizer as max_minibatches).
    rtt = None
    if 'reduce_train_to' in state:
        rtt = state['reduce_train_to']
    elif REDUCE_TRAIN_TO:
        rtt = REDUCE_TRAIN_TO

    decrease_lr = state['decrease_lr'] if 'decrease_lr' in state else 0
    dec = (state['decrease_lr_pretrain']
           if 'decrease_lr_pretrain' in state else 0)

    # Input images are 32x32 pixels, flattened.
    n_ins = 32 * 32

    if 'subdataset' in state:
        subdataset_name = state['subdataset']
    else:
        subdataset_name = SUBDATASET_NIST

    # Number of classes, dataset and epoch size depend on the subdataset;
    # any other name falls back to all 62 classes
    # (10 digits, 26*2 lower and capital letters).
    if subdataset_name == "upper":
        n_outs = 26
        subdataset = datasets.nist_upper()
        examples_per_epoch = NIST_UPPER_TRAIN_SIZE
    elif subdataset_name == "lower":
        n_outs = 26
        subdataset = datasets.nist_lower()
        examples_per_epoch = NIST_LOWER_TRAIN_SIZE
    elif subdataset_name == "digits":
        n_outs = 10
        subdataset = datasets.nist_digits()
        examples_per_epoch = NIST_DIGITS_TRAIN_SIZE
    else:
        n_outs = 62
        subdataset = datasets.nist_all()
        examples_per_epoch = NIST_ALL_TRAIN_SIZE

    # Single-argument print: same output as the former Python 2
    # `print 'Using subdataset ', name` statement (two spaces).
    print('Using subdataset  ' + str(subdataset_name))

    # HDF5 file in the working directory where the training series go.
    pretrain_choice = state['pretrain_choice']
    if pretrain_choice == 0:
        nom_serie = "series_NIST.h5"
    elif pretrain_choice == 1:
        nom_serie = "series_P07.h5"
    else:
        # Without a file name, create_series would try to open the working
        # directory itself as an HDF5 file.
        raise ValueError("pretrain_choice must be 0 or 1, got %r"
                         % (pretrain_choice,))

    series = create_series(state.num_hidden_layers, nom_serie)

    print('Creating optimizer with state,  ' + str(state))

    optimizer = SdaSgdOptimizer(dataset_name=subdataset_name,
                                dataset=subdataset,
                                hyperparameters=state,
                                n_ins=n_ins, n_outs=n_outs,
                                examples_per_epoch=examples_per_epoch,
                                series=series,
                                max_minibatches=rtt)

    print('\n\tpretraining with NIST\n')

    optimizer.pretrain(subdataset, decrease=dec)

    channel.save()

    # Set some of the parameters used for the finetuning.
    if 'finetune_set' in state:
        finetune_choice = state['finetune_set']
    else:
        finetune_choice = FINETUNE_SET

    if 'max_finetuning_epochs' in state:
        max_finetune_epoch_NIST = state['max_finetuning_epochs']
    else:
        max_finetune_epoch_NIST = MAX_FINETUNING_EPOCHS

    if 'max_finetuning_epochs_P07' in state:
        max_finetune_epoch_P07 = state['max_finetuning_epochs_P07']
    else:
        max_finetune_epoch_P07 = max_finetune_epoch_NIST

    # Decide how the finetuning is done.  Each variant restarts from the
    # pretrained parameters, except "P07 followed by NIST" in the -1 series,
    # which restarts from the parameters saved by the P07 finetuning.
    if finetune_choice == 0:
        print('\n\n\tfinetune with NIST\n\n')
        optimizer.reload_parameters('params_pretrain.txt')
        optimizer.finetune(subdataset, subdataset, max_finetune_epoch_NIST,
                           ind_test=1, decrease=decrease_lr)
        channel.save()
    elif finetune_choice == 1:
        print('\n\n\tfinetune with P07\n\n')
        optimizer.reload_parameters('params_pretrain.txt')
        optimizer.finetune(datasets.nist_P07(), datasets.nist_all(),
                           max_finetune_epoch_P07,
                           ind_test=0, decrease=decrease_lr)
        channel.save()
    elif finetune_choice == 2:
        print('\n\n\tfinetune with P07 followed by NIST\n\n')
        optimizer.reload_parameters('params_pretrain.txt')
        optimizer.finetune(datasets.nist_P07(), datasets.nist_all(),
                           max_finetune_epoch_P07,
                           ind_test=20, decrease=decrease_lr)
        optimizer.finetune(datasets.nist_all(), datasets.nist_P07(),
                           max_finetune_epoch_NIST,
                           ind_test=21, decrease=decrease_lr)
        channel.save()
    elif finetune_choice == 3:
        print('\n\n\tfinetune with NIST only on the logistic regression '
              'on top (but validation on P07).\n'
              'All hidden units output are input of the logistic '
              'regression\n\n')
        optimizer.reload_parameters('params_pretrain.txt')
        optimizer.finetune(datasets.nist_all(), datasets.nist_P07(),
                           max_finetune_epoch_NIST,
                           ind_test=1, special=1, decrease=decrease_lr)
    elif finetune_choice == -1:
        print('\nSERIE OF 4 DIFFERENT FINETUNINGS')
        print('\n\n\tfinetune with NIST\n\n')
        sys.stdout.flush()
        optimizer.reload_parameters('params_pretrain.txt')
        optimizer.finetune(datasets.nist_all(), datasets.nist_P07(),
                           max_finetune_epoch_NIST,
                           ind_test=1, decrease=decrease_lr)
        channel.save()
        print('\n\n\tfinetune with P07\n\n')
        sys.stdout.flush()
        optimizer.reload_parameters('params_pretrain.txt')
        optimizer.finetune(datasets.nist_P07(), datasets.nist_all(),
                           max_finetune_epoch_P07,
                           ind_test=0, decrease=decrease_lr)
        channel.save()
        print('\n\n\tfinetune with P07 (done earlier) followed by NIST '
              '(written here)\n\n')
        sys.stdout.flush()
        optimizer.reload_parameters('params_finetune_P07.txt')
        optimizer.finetune(datasets.nist_all(), datasets.nist_P07(),
                           max_finetune_epoch_NIST,
                           ind_test=21, decrease=decrease_lr)
        channel.save()
        print('\n\n\tfinetune with NIST only on the logistic regression '
              'on top.\n'
              'All hidden units output are input of the logistic '
              'regression\n\n')
        sys.stdout.flush()
        optimizer.reload_parameters('params_pretrain.txt')
        optimizer.finetune(datasets.nist_all(), datasets.nist_P07(),
                           max_finetune_epoch_NIST,
                           ind_test=1, special=1, decrease=decrease_lr)
        channel.save()
    else:
        # Previously an unknown value silently skipped all finetuning while
        # the job was still reported COMPLETE; make that visible in the log.
        print('WARNING: unknown finetune_set %r, no finetuning done'
              % (finetune_choice,))

    channel.save()

    return channel.COMPLETE
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
192
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
193 # These Series objects are used to save various statistics
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
194 # during the training.
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
195 def create_series(num_hidden_layers, nom_serie):
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
196
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
197 # Replace series we don't want to save with DummySeries, e.g.
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
198 # series['training_error'] = DummySeries()
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
199
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
200 series = {}
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
201
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
202 basedir = os.getcwd()
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
203
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
204 h5f = tables.openFile(os.path.join(basedir, nom_serie), "w")
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
205
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
206 # reconstruction
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
207 reconstruction_base = \
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
208 ErrorSeries(error_name="reconstruction_error",
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
209 table_name="reconstruction_error",
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
210 hdf5_file=h5f,
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
211 index_names=('epoch','minibatch'),
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
212 title="Reconstruction error (mean over "+str(REDUCE_EVERY)+" minibatches)")
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
213 series['reconstruction_error'] = \
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
214 AccumulatorSeriesWrapper(base_series=reconstruction_base,
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
215 reduce_every=REDUCE_EVERY)
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
216
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
217 # train
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
218 training_base = \
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
219 ErrorSeries(error_name="training_error",
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
220 table_name="training_error",
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
221 hdf5_file=h5f,
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
222 index_names=('epoch','minibatch'),
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
223 title="Training error (mean over "+str(REDUCE_EVERY)+" minibatches)")
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
224 series['training_error'] = \
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
225 AccumulatorSeriesWrapper(base_series=training_base,
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
226 reduce_every=REDUCE_EVERY)
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
227
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
228 # valid and test are not accumulated/mean, saved directly
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
229 series['validation_error'] = \
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
230 ErrorSeries(error_name="validation_error",
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
231 table_name="validation_error",
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
232 hdf5_file=h5f,
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
233 index_names=('epoch','minibatch'))
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
234
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
235 series['test_error'] = \
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
236 ErrorSeries(error_name="test_error",
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
237 table_name="test_error",
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
238 hdf5_file=h5f,
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
239 index_names=('epoch','minibatch'))
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
240
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
241 param_names = []
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
242 for i in range(num_hidden_layers):
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
243 param_names += ['layer%d_W'%i, 'layer%d_b'%i, 'layer%d_bprime'%i]
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
244 param_names += ['logreg_layer_W', 'logreg_layer_b']
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
245
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
246 # comment out series we don't want to save
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
247 series['params'] = SharedParamsStatisticsWrapper(
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
248 new_group_name="params",
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
249 base_group="/",
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
250 arrays_names=param_names,
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
251 hdf5_file=h5f,
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
252 index_names=('epoch',))
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
253
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
254 return series
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
255
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
256 # Perform insertion into the Postgre DB based on combination
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
257 # of hyperparameter values above
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
258 # (see comment for produit_cartesien_jobs() to know how it works)
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
259 def jobman_insert_nist():
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
260 jobs = produit_cartesien_jobs(JOB_VALS)
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
261
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
262 db = jobman.sql.db(JOBDB)
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
263 for job in jobs:
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
264 job.update({jobman.sql.EXPERIMENT: EXPERIMENT_PATH})
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
265 jobman.sql.insert_dict(job, db)
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
266
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
267 print "inserted"
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
268
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
269 if __name__ == '__main__':
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
270
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
271 args = sys.argv[1:]
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
272
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
273 #if len(args) > 0 and args[0] == 'load_nist':
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
274 # test_load_nist()
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
275
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
276 if len(args) > 0 and args[0] == 'jobman_insert':
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
277 jobman_insert_nist()
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
278
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
279 elif len(args) > 0 and args[0] == 'test_jobman_entrypoint':
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
280 chanmock = DD({'COMPLETE':0,'save':(lambda:None)})
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
281 jobman_entrypoint(DD(DEFAULT_HP_NIST), chanmock)
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
282
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
283 else:
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
284 print "Bad arguments"
0ca069550abd Added : single class version of SDA
Guillaume Sicard <guitch21@gmail.com>
parents:
diff changeset
285