comparison deep/stacked_dae/nist_sda.py @ 284:8a3af19ae272

Removed the mechanism for limiting the number of examples used (replaced by a parameter in the call to the dataset code), and added an option to save the weights at the end of training.
author fsavard
date Wed, 24 Mar 2010 15:13:48 -0400
parents 206374eed2fb
children
--- deep/stacked_dae/nist_sda.py  279:206374eed2fb
+++ deep/stacked_dae/nist_sda.py  284:8a3af19ae272
@@ -54,26 +54,34 @@
 
     n_ins = 32*32
     n_outs = 62 # 10 digits, 26*2 (lower, capitals)
 
     examples_per_epoch = NIST_ALL_TRAIN_SIZE
+    if rtt:
+        examples_per_epoch = rtt
 
     series = create_series(state.num_hidden_layers)
 
     print "Creating optimizer with state, ", state
 
-    optimizer = SdaSgdOptimizer(dataset=datasets.nist_all(),
+    dataset = None
+    if rtt:
+        dataset = datasets.nist_all(maxsize=rtt)
+    else:
+        dataset = datasets.nist_all()
+
+    optimizer = SdaSgdOptimizer(dataset=dataset,
                                 hyperparameters=state, \
                                 n_ins=n_ins, n_outs=n_outs,\
                                 examples_per_epoch=examples_per_epoch, \
                                 series=series,
-                                max_minibatches=rtt)
+                                save_params=SAVE_PARAMS)
 
-    optimizer.pretrain(datasets.nist_all())
+    optimizer.pretrain(dataset)
     channel.save()
 
-    optimizer.finetune(datasets.nist_all())
+    optimizer.finetune(dataset)
     channel.save()
 
     return channel.COMPLETE
 
 # These Series objects are used to save various statistics
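
Neither the `maxsize` handling inside `datasets.nist_all()` nor the effect of `save_params=SAVE_PARAMS` is shown in this hunk, only their call sites. The sketch below is a hypothetical illustration (not this repository's code) of the two behaviours the commit message describes, assuming the dataset can be treated as an iterable of examples and the parameters can be pickled; the helper names `limited_dataset` and `maybe_save_params` are invented for the example.

    import itertools
    import pickle

    def limited_dataset(examples, maxsize=None):
        # Cap the number of examples yielded, which is roughly what a
        # `maxsize` keyword on the dataset constructor could do instead of
        # capping minibatches inside the optimizer (`max_minibatches`).
        if maxsize is None:
            return examples
        return itertools.islice(examples, maxsize)

    def maybe_save_params(params, enabled, path="sdae_params.pkl"):
        # If the save-weights option is enabled, dump the parameters to
        # disk at the end of training; file name and format are placeholders.
        if enabled:
            with open(path, "wb") as f:
                pickle.dump(list(params), f)

    # Example: keep only the first 500 examples of a large stream.
    capped = limited_dataset(iter(range(1000000)), maxsize=500)
    assert len(list(capped)) == 500

Threading a single `dataset` object through `pretrain` and `finetune`, as the new code does, keeps the example cap in one place instead of repeating `datasets.nist_all()` at each call site.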