diff deep/stacked_dae/nist_sda.py @ 284:8a3af19ae272
Removed the mechanism for limiting the number of examples used (replaced by a parameter in the dataset call), and added an option to save the weights at the end of training
author:   fsavard
date:     Wed, 24 Mar 2010 15:13:48 -0400
parents:  206374eed2fb
children: (none)
--- a/deep/stacked_dae/nist_sda.py	Wed Mar 24 14:36:55 2010 -0400
+++ b/deep/stacked_dae/nist_sda.py	Wed Mar 24 15:13:48 2010 -0400
@@ -56,22 +56,30 @@
     n_outs = 62 # 10 digits, 26*2 (lower, capitals)
 
     examples_per_epoch = NIST_ALL_TRAIN_SIZE
+    if rtt:
+        examples_per_epoch = rtt
 
     series = create_series(state.num_hidden_layers)
 
     print "Creating optimizer with state, ", state
 
-    optimizer = SdaSgdOptimizer(dataset=datasets.nist_all(),
+    dataset = None
+    if rtt:
+        dataset = datasets.nist_all(maxsize=rtt)
+    else:
+        dataset = datasets.nist_all()
+
+    optimizer = SdaSgdOptimizer(dataset=dataset,
                                     hyperparameters=state, \
                                     n_ins=n_ins, n_outs=n_outs,\
                                     examples_per_epoch=examples_per_epoch, \
                                     series=series,
-                                    max_minibatches=rtt)
+                                    save_params=SAVE_PARAMS)
 
-    optimizer.pretrain(datasets.nist_all())
+    optimizer.pretrain(dataset)
     channel.save()
 
-    optimizer.finetune(datasets.nist_all())
+    optimizer.finetune(dataset)
     channel.save()
 
     return channel.COMPLETE
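The change above replaces the optimizer-side `max_minibatches=rtt` cap with a dataset-side cap: when `rtt` is set, both `examples_per_epoch` and the dataset itself are truncated at the source via `maxsize`. The sketch below mimics that pattern with stand-in code; `nist_all`, `build_dataset`, and the size of 100 are illustrative assumptions, not the real `datasets` module or `SdaSgdOptimizer` from this repository.

```python
def nist_all(maxsize=None):
    # Stand-in for the dataset factory: the real one streams NIST data;
    # here a list of ints plays the role of the training examples.
    full = list(range(100))  # pretend NIST_ALL_TRAIN_SIZE == 100
    return full if maxsize is None else full[:maxsize]

def build_dataset(rtt=None):
    # Mirrors the diff: when rtt is set, cap both examples_per_epoch
    # and the dataset itself, so every consumer (pretrain, finetune)
    # sees the same reduced data without optimizer-side limiting.
    examples_per_epoch = 100  # NIST_ALL_TRAIN_SIZE in the real code
    if rtt:
        examples_per_epoch = rtt
        dataset = nist_all(maxsize=rtt)
    else:
        dataset = nist_all()
    return dataset, examples_per_epoch

dataset, epe = build_dataset(rtt=10)
print(len(dataset), epe)  # both capped to 10
```

Capping at the dataset level means the same truncated dataset object is reused for both `optimizer.pretrain(dataset)` and `optimizer.finetune(dataset)`, which is why the diff also drops the two separate `datasets.nist_all()` calls.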