Mercurial > pylearn
comparison pylearn/algorithms/tests/test_mcRBM.py @ 1512:7f166d01bf8e
Remove deprecation warning.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Mon, 12 Sep 2011 11:59:55 -0400 |
parents | b709f6b53b17 |
children | 9d21919e2332 |
comparison
equal
deleted
inserted
replaced
1511:9ffe5d6faee3 | 1512:7f166d01bf8e |
---|---|
133 | 133 |
134 print_jj = epoch != last_epoch | 134 print_jj = epoch != last_epoch |
135 last_epoch = epoch | 135 last_epoch = epoch |
136 | 136 |
137 if as_unittest and epoch == 5: | 137 if as_unittest and epoch == 5: |
138 U = rbm.U.value | 138 U = rbm.U.get_value(borrow=True) |
139 W = rbm.W.value | 139 W = rbm.W.get_value(borrow=True) |
140 def allclose(a,b): | 140 def allclose(a,b): |
141 return numpy.allclose(a,b,rtol=1.01,atol=1e-3) | 141 return numpy.allclose(a,b,rtol=1.01,atol=1e-3) |
142 print "" | 142 print "" |
143 print "--------------" | 143 print "--------------" |
144 print "assert allclose(l2(U), %f)"%l2(U) | 144 print "assert allclose(l2(U), %f)"%l2(U) |
166 tile(rbm.U.value.T, "U_%06i.png"%jj) | 166 tile(rbm.U.value.T, "U_%06i.png"%jj) |
167 tile(rbm.W.value.T, "W_%06i.png"%jj) | 167 tile(rbm.W.value.T, "W_%06i.png"%jj) |
168 | 168 |
169 print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize) | 169 print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize) |
170 | 170 |
171 print 'l2(U)', l2(rbm.U.value), | 171 print 'l2(U)', l2(rbm.U.get_value(borrow=True)), |
172 print 'l2(W)', l2(rbm.W.value), | 172 print 'l2(W)', l2(rbm.W.get_value(borrow=True)), |
173 print 'l1_penalty', | 173 print 'l1_penalty', |
174 try: | 174 try: |
175 print trainer.effective_l1_penalty.value | 175 print trainer.effective_l1_penalty.get_value(borrow=True) |
176 except: | 176 except: |
177 print trainer.effective_l1_penalty | 177 print trainer.effective_l1_penalty |
178 | 178 |
179 print 'U min max', rbm.U.value.min(), rbm.U.value.max(), | 179 print 'U min max', rbm.U.get_value(borrow=True).min(), rbm.U.get_value(borrow=True).max(), |
180 print 'W min max', rbm.W.value.min(), rbm.W.value.max(), | 180 print 'W min max', rbm.W.get_value(borrow=True).min(), rbm.W.get_value(borrow=True).max(), |
181 print 'a min max', rbm.a.value.min(), rbm.a.value.max(), | 181 print 'a min max', rbm.a.get_value(borrow=True).min(), rbm.a.get_value(borrow=True).max(), |
182 print 'b min max', rbm.b.value.min(), rbm.b.value.max(), | 182 print 'b min max', rbm.b.get_value(borrow=True).min(), rbm.b.get_value(borrow=True).max(), |
183 print 'c min max', rbm.c.value.min(), rbm.c.value.max() | 183 print 'c min max', rbm.c.get_value(borrow=True).min(), rbm.c.get_value(borrow=True).max() |
184 | 184 |
185 if persistent_chains: | 185 if persistent_chains: |
186 print 'parts min', smplr.positions.value.min(), | 186 print 'parts min', smplr.positions.get_value(borrow=True).min(), |
187 print 'max',smplr.positions.value.max(), | 187 print 'max',smplr.positions.get_value(borrow=True).max(), |
188 print 'HMC step', smplr.stepsize.value, | 188 print 'HMC step', smplr.stepsize.get_value(borrow=True), |
189 print 'arate', smplr.avg_acceptance_rate.value | 189 print 'arate', smplr.avg_acceptance_rate.get_value(borrow=True) |
190 | 190 |
191 | 191 |
192 l2_of_Ugrad = learn_fn(jj) | 192 l2_of_Ugrad = learn_fn(jj) |
193 | 193 |
194 if persistent_chains and print_jj: | 194 if persistent_chains and print_jj: |
273 tile(rbm.U.value.T, "U_%06i.png"%ii) | 273 tile(rbm.U.value.T, "U_%06i.png"%ii) |
274 tile(rbm.W.value.T, "W_%06i.png"%ii) | 274 tile(rbm.W.value.T, "W_%06i.png"%ii) |
275 | 275 |
276 print 'saving samples', ii, 'epoch', i_epoch, i_batch | 276 print 'saving samples', ii, 'epoch', i_epoch, i_batch |
277 | 277 |
278 print 'l2(U)', l2(rbm.U.value), | 278 print 'l2(U)', l2(rbm.U.get_value(borrow=True)), |
279 print 'l2(W)', l2(rbm.W.value), | 279 print 'l2(W)', l2(rbm.W.get_value(borrow=True)), |
280 print 'l1_penalty', | 280 print 'l1_penalty', |
281 try: | 281 try: |
282 print trainer.effective_l1_penalty.value | 282 print trainer.effective_l1_penalty.get_value(borrow=True) |
283 except: | 283 except: |
284 print trainer.effective_l1_penalty | 284 print trainer.effective_l1_penalty |
285 | 285 |
286 print 'U min max', rbm.U.value.min(), rbm.U.value.max(), | 286 print 'U min max', rbm.U.get_value(borrow=True).min(), rbm.U.get_value(borrow=True).max(), |
287 print 'W min max', rbm.W.value.min(), rbm.W.value.max(), | 287 print 'W min max', rbm.W.get_value(borrow=True).min(), rbm.W.get_value(borrow=True).max(), |
288 print 'a min max', rbm.a.value.min(), rbm.a.value.max(), | 288 print 'a min max', rbm.a.get_value(borrow=True).min(), rbm.a.get_value(borrow=True).max(), |
289 print 'b min max', rbm.b.value.min(), rbm.b.value.max(), | 289 print 'b min max', rbm.b.get_value(borrow=True).min(), rbm.b.get_value(borrow=True).max(), |
290 print 'c min max', rbm.c.value.min(), rbm.c.value.max() | 290 print 'c min max', rbm.c.get_value(borrow=True).min(), rbm.c.get_value(borrow=True).max() |
291 | 291 |
292 print 'HMC step', smplr.stepsize.value, | 292 print 'HMC step', smplr.stepsize.get_value(borrow=True), |
293 print 'arate', smplr.avg_acceptance_rate.value | 293 print 'arate', smplr.avg_acceptance_rate.get_value(borrow=True) |
294 print 'P min max', rbm.P.value.min(), rbm.P.value.max(), | 294 print 'P min max', rbm.P.get_value(borrow=True).min(), rbm.P.get_value(borrow=True).max(), |
295 print 'P_lr', trainer.p_lr.value | 295 print 'P_lr', trainer.p_lr.get_value(borrow=True) |
296 print '' | 296 print '' |
297 print 'Saving rbm...' | 297 print 'Saving rbm...' |
298 cPickle.dump(rbm, open('mcRBM.rbm.%06i.pkl'%ii, 'w'), -1) | 298 cPickle.dump(rbm, open('mcRBM.rbm.%06i.pkl'%ii, 'w'), -1) |
299 | 299 |
300 ii += 1 | 300 ii += 1 |