Mercurial > ift6266
comparison deep/stacked_dae/nist_sda.py @ 284:8a3af19ae272
Enlevé mécanique pour limiter le nombre d'exemples utilisés (remplacé par paramètre dans l'appel au code de dataset), et ajouté option pour sauvegarde des poids à la fin de l'entraînement
author | fsavard |
---|---|
date | Wed, 24 Mar 2010 15:13:48 -0400 |
parents | 206374eed2fb |
children |
comparison
equal
deleted
inserted
replaced
279:206374eed2fb | 284:8a3af19ae272 |
---|---|
54 | 54 |
55 n_ins = 32*32 | 55 n_ins = 32*32 |
56 n_outs = 62 # 10 digits, 26*2 (lower, capitals) | 56 n_outs = 62 # 10 digits, 26*2 (lower, capitals) |
57 | 57 |
58 examples_per_epoch = NIST_ALL_TRAIN_SIZE | 58 examples_per_epoch = NIST_ALL_TRAIN_SIZE |
59 if rtt: | |
60 examples_per_epoch = rtt | |
59 | 61 |
60 series = create_series(state.num_hidden_layers) | 62 series = create_series(state.num_hidden_layers) |
61 | 63 |
62 print "Creating optimizer with state, ", state | 64 print "Creating optimizer with state, ", state |
63 | 65 |
64 optimizer = SdaSgdOptimizer(dataset=datasets.nist_all(), | 66 dataset = None |
67 if rtt: | |
68 dataset = datasets.nist_all(maxsize=rtt) | |
69 else: | |
70 dataset = datasets.nist_all() | |
71 | |
72 optimizer = SdaSgdOptimizer(dataset=dataset, | |
65 hyperparameters=state, \ | 73 hyperparameters=state, \ |
66 n_ins=n_ins, n_outs=n_outs,\ | 74 n_ins=n_ins, n_outs=n_outs,\ |
67 examples_per_epoch=examples_per_epoch, \ | 75 examples_per_epoch=examples_per_epoch, \ |
68 series=series, | 76 series=series, |
69 max_minibatches=rtt) | 77 save_params=SAVE_PARAMS) |
70 | 78 |
71 optimizer.pretrain(datasets.nist_all()) | 79 optimizer.pretrain(dataset) |
72 channel.save() | 80 channel.save() |
73 | 81 |
74 optimizer.finetune(datasets.nist_all()) | 82 optimizer.finetune(dataset) |
75 channel.save() | 83 channel.save() |
76 | 84 |
77 return channel.COMPLETE | 85 return channel.COMPLETE |
78 | 86 |
79 # These Series objects are used to save various statistics | 87 # These Series objects are used to save various statistics |