comparison pylearn/algorithms/mcRBM.py @ 992:30b7c4defb6c
mcRBM - it works and committing it is taking forever... let's try this approach
| author | James Bergstra <bergstrj@iro.umontreal.ca> |
| --- | --- |
| date | Tue, 24 Aug 2010 14:52:09 -0400 |
| parents | d68828c98c38 |
| children | 88107ec01ce8 |
991:d68828c98c38 | 992:30b7c4defb6c |
---|---|
254 """ | 254 """ |
255 (U,W,a,b,c) = rbm | 255 (U,W,a,b,c) = rbm |
256 unit_v = v / (TT.sqrt(TT.mean(v**2, axis=1)+small)).dimshuffle(0,'x') # adjust row norm | 256 unit_v = v / (TT.sqrt(TT.mean(v**2, axis=1)+small)).dimshuffle(0,'x') # adjust row norm |
257 return b - 0.5 * dot(unit_v, U)**2 | 257 return b - 0.5 * dot(unit_v, U)**2 |
258 | 258 |
259 def free_energy_given_v(rbm, v): | 259 def free_energy_terms_given_v(rbm, v): |
260 """Returns theano expression for free energy of visible vector `v` in an mcRBM | 260 """Returns theano expression for the terms that are added to form the free energy of |
261 | 261 visible vector `v` in an mcRBM. |
262 An mcRBM is parametrized | 262 |
263 by `U`, `W`, `b`, `c`. | 263 1. Free energy related to covariance hiddens |
264 See module-level documentation for explanations of the `U`, `W`, `b` and `c` parameters. | 264 2. Free energy related to mean hiddens |
265 | 265 3. Free energy related to L2-Norm of `v` |
266 | 266 4. Free energy related to projection of `v` onto biases `a` |
267 The free energy of v is what we need for learning and hybrid Monte Carlo negative-phase | |
268 sampling. | |
269 | |
270 """ | 267 """ |
271 U, W, a, b, c = rbm | 268 U, W, a, b, c = rbm |
272 t0 = -TT.sum(TT.nnet.softplus(hidden_cov_units_preactivation_given_v(rbm, v)),axis=1) | 269 t0 = -TT.sum(TT.nnet.softplus(hidden_cov_units_preactivation_given_v(rbm, v)),axis=1) |
273 t1 = -TT.sum(TT.nnet.softplus(c + dot(v,W)), axis=1) | 270 t1 = -TT.sum(TT.nnet.softplus(c + dot(v,W)), axis=1) |
274 t2 = 0.5 * TT.sum(v**2, axis=1) | 271 t2 = 0.5 * TT.sum(v**2, axis=1) |
275 t3 = -TT.dot(v, a) | 272 t3 = -TT.dot(v, a) |
276 return t0 + t1 + t2 + t3, (t0, t1, t2, t3) | 273 return [t0, t1, t2, t3] |
274 | |
275 def free_energy_given_v(rbm, v): | |
276 """Returns theano expression for free energy of visible vector `v` in an mcRBM | |
277 """ | |
278 return sum(free_energy_terms_given_v(rbm,v)) | |
277 | 279 |
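
The four terms mirror the mcRBM energy function. A minimal NumPy sketch of the same computation, for reference (the array shapes and the `small` stabilizer here are assumptions, not values taken from the module):

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)), elementwise
    return np.log1p(np.exp(x))

def free_energy_terms(U, W, a, b, c, v, small=0.5):
    # covariance-unit preactivation: rows of v are rescaled toward unit norm first
    unit_v = v / np.sqrt(np.mean(v ** 2, axis=1, keepdims=True) + small)
    t0 = -softplus(b - 0.5 * np.dot(unit_v, U) ** 2).sum(axis=1)  # covariance hiddens
    t1 = -softplus(c + np.dot(v, W)).sum(axis=1)                  # mean hiddens
    t2 = 0.5 * (v ** 2).sum(axis=1)                               # L2-norm of v
    t3 = -np.dot(v, a)                                            # projection onto biases a
    return [t0, t1, t2, t3]

# free_energy_given_v is then just sum(free_energy_terms(U, W, a, b, c, v))
```
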
278 def contrastive_gradient(rbm, pos_v, neg_v, U_l1_penalty=0, W_l1_penalty=0): | 280 def contrastive_gradient(rbm, pos_v, neg_v, U_l1_penalty=0, W_l1_penalty=0): |
279 """Return a list of gradient expressions for the rbm parameters | 281 """Return a list of gradient expressions for the rbm parameters |
280 | 282 |
281 :param pos_v: positive-phase sample of visible units | 283 :param pos_v: positive-phase sample of visible units |
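
The body of `contrastive_gradient` is elided from this hunk. A hedged sketch of what a function with this signature typically computes in Theano (an assumption, not the actual implementation): the gradient of the free-energy gap between the data batch and the sampler's negative particles, plus the L1 penalty terms on `U` and `W`.

```python
# Sketch only -- the real body is not shown in this hunk.
(U, W, a, b, c) = rbm
cost = (free_energy_given_v(rbm, pos_v).sum()
        - free_energy_given_v(rbm, neg_v).sum()
        + U_l1_penalty * abs(U).sum()
        + W_l1_penalty * abs(W).sum())
grads = TT.grad(cost, [U, W, a, b, c])  # one gradient expression per parameter
```
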
361 return HMC_sampler( | 363 return HMC_sampler( |
362 positions = [as_shared( | 364 positions = [as_shared( |
363 np.random.RandomState(seed^20893).randn( | 365 np.random.RandomState(seed^20893).randn( |
364 n_particles, | 366 n_particles, |
365 self.n_visible ))], | 367 self.n_visible ))], |
366 energy_fn = lambda p : self.free_energy_given_v(p[0]), | 368 energy_fn = lambda p : free_energy_given_v(self.params, p[0]), |
367 seed=seed) | 369 seed=seed) |
368 | 370 |
369 def free_energy_given_v(self, v, extra=False): | 371 def free_energy_given_v(self, v, extra=False): |
372 assert 0 | |
370 rval = free_energy_given_v(self.params, v) | 373 rval = free_energy_given_v(self.params, v) |
371 if extra: | 374 if extra: |
372 return rval | 375 return rval |
373 else: | 376 else: |
374 return rval[0] | 377 return rval[0] |
375 | 378 |
376 def contrastive_gradient(self, *args, **kwargs) | 379 def contrastive_gradient(self, *args, **kwargs): |
377 """Return a list of gradient expressions for self.params | 380 """Return a list of gradient expressions for self.params |
378 | 381 |
379 See `contrastive_gradient` for parameters. | 382 :param pos_v: positive-phase sample of visible units |
383 :param neg_v: negative-phase sample of visible units | |
380 """ | 384 """ |
381 return contrastive_gradient(self.params, *args, **kwargs) | 385 return contrastive_gradient(self.params, *args, **kwargs) |
382 | 386 |
383 | 387 |
384 if __name__ == '__main__': | 388 if __name__ == '__main__': |
392 demodata = scipy.io.loadmat(os.path.join(pylearn.datasets.config.data_root(),'image_patches', 'mcRBM', 'training_colorpatches_16x16_demo.mat')) | 396 demodata = scipy.io.loadmat(os.path.join(pylearn.datasets.config.data_root(),'image_patches', 'mcRBM', 'training_colorpatches_16x16_demo.mat')) |
393 else: | 397 else: |
394 R,C= 16,16 # the size of image patches | 398 R,C= 16,16 # the size of image patches |
395 n_patches=100000 | 399 n_patches=100000 |
396 | 400 |
397 n_train_iters=30000 | 401 n_train_iters=5000 |
398 | 402 |
399 n_burnin_steps=10000 | 403 n_burnin_steps=10000 |
400 | 404 |
401 l1_penalty=1e-3 | 405 l1_penalty=1e-3 |
402 no_l1_epochs = 10 | 406 no_l1_epochs = 10 |
403 effective_l1_penalty=0.0 | 407 effective_l1_penalty=0.0 |
404 | 408 |
405 epoch_size=50000 | 409 epoch_size=n_patches |
406 batchsize = 128 | 410 batchsize = 128 |
407 lr = 0.075 / batchsize | 411 lr = 0.075 / batchsize |
408 s_lr = TT.scalar() | 412 s_lr = TT.scalar() |
409 s_l1_penalty=TT.scalar() | 413 s_l1_penalty=TT.scalar() |
410 n_K=256 | 414 n_K=256 |
418 ) | 422 ) |
419 | 423 |
420 sampler = rbm.hmc_sampler(n_particles=batchsize) | 424 sampler = rbm.hmc_sampler(n_particles=batchsize) |
421 | 425 |
422 def l2(X): | 426 def l2(X): |
423 return (X**2).sum() | 427 return numpy.sqrt((X**2).sum()) |
424 def tile(X, fname): | 428 def tile(X, fname): |
425 if dataset == 'MAR': | 429 if dataset == 'MAR': |
426 X = np.dot(X, demodata['invpcatransf'].T) | 430 X = np.dot(X, demodata['invpcatransf'].T) |
427 R=16 | 431 R=16 |
428 C=16 | 432 C=16 |
448 | 452 |
449 sys.exit() | 453 sys.exit() |
450 | 454 |
451 batch_idx = TT.iscalar() | 455 batch_idx = TT.iscalar() |
452 | 456 |
453 if 0: | 457 if dataset == 'MAR': |
458 op = TensorFnDataset(floatX, | |
459 bcast=(False,), | |
460 fn=load_mcRBM_demo_patches, | |
461 single_shape=(105,)) | |
462 train_batch = op((batch_idx * batchsize + np.arange(batchsize))%n_patches) | |
463 else: | |
454 from pylearn.dataset_ops import image_patches | 464 from pylearn.dataset_ops import image_patches |
455 train_batch = image_patches.image_patches( | 465 train_batch = image_patches.image_patches( |
456 s_idx = (batch_idx * batchsize + np.arange(batchsize)), | 466 s_idx = (batch_idx * batchsize + np.arange(batchsize)), |
457 dims = (n_patches,R,C), | 467 dims = (n_patches,R,C), |
458 center=True, | 468 center=True, |
459 unitvar=True, | 469 unitvar=True, |
460 dtype=floatX, | 470 dtype=floatX, |
461 rasterized=True) | 471 rasterized=True) |
462 else: | |
463 op = TensorFnDataset(floatX, | |
464 bcast=(False,), | |
465 fn=load_mcRBM_demo_patches, | |
466 single_shape=(105,)) | |
467 train_batch = op((batch_idx * batchsize + np.arange(batchsize))%n_patches) | |
468 | 472 |
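
In the MAR branch the batch is addressed with explicit wrap-around, so iteration can run past one pass over the data:

```python
# batch 0 -> patches 0..batchsize-1, batch 1 -> the next batchsize patches, ...
# the modulo wraps back to the start once batch_idx * batchsize exceeds n_patches
idx = (batch_idx * batchsize + np.arange(batchsize)) % n_patches
```
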
469 imgs_fn = function([batch_idx], outputs=train_batch) | 473 imgs_fn = function([batch_idx], outputs=train_batch) |
470 | 474 |
471 grads = rbm.contrastive_gradient( | 475 grads = rbm.contrastive_gradient( |
472 pos_v=train_batch, | 476 pos_v=train_batch, |
473 neg_v=sampler.positions[0], | 477 neg_v=sampler.positions[0], |
474 U_l1_penalty=s_l1_penalty, | 478 U_l1_penalty=s_l1_penalty, |
475 W_l1_penalty=s_l1_penalty) | 479 W_l1_penalty=s_l1_penalty) |
480 sgd_ups = sgd_updates( | |
481 rbm.params, | |
482 grads, | |
483 lr=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ]) | |
476 | 484 |
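
`sgd_updates` is defined elsewhere in pylearn. A plausible reading of its contract, consistent with how `sgd_ups` is indexed below (an assumed sketch, not the pylearn source): one `(shared_variable, update_expression)` pair per parameter, each with its own learning rate from the `lr` list.

```python
# Assumed sketch of the helper's contract (not the pylearn source):
def sgd_updates(params, grads, lr):
    return [(p, p - step * g) for p, g, step in zip(params, grads, lr)]
```

Under that reading, `(sgd_ups[0][1] - sgd_ups[0][0]).norm(2)` in the outputs below is the L2 norm of the actual parameter increment for `U`, a useful monitor alongside the raw gradient norm.
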
477 learn_fn = function([batch_idx, s_lr, s_l1_penalty], | 485 learn_fn = function([batch_idx, s_lr, s_l1_penalty], |
478 outputs=[ | 486 outputs=[ |
479 grads[0].norm(2), | 487 grads[0].norm(2), |
480 rbm.free_energy_given_v(train_batch).sum(), | 488 (sgd_ups[0][1] - sgd_ups[0][0]).norm(2), |
481 rbm.free_energy_given_v(train_batch,extra=1)[1][0].sum(), | 489 (sgd_ups[1][1] - sgd_ups[1][0]).norm(2), |
482 rbm.free_energy_given_v(train_batch,extra=1)[1][1].sum(), | |
483 rbm.free_energy_given_v(train_batch,extra=1)[1][2].sum(), | |
484 rbm.free_energy_given_v(train_batch,extra=1)[1][3].sum(), | |
485 ], | 490 ], |
486 updates = sgd_updates( | 491 updates = sgd_ups) |
487 rbm.params, | 492 #rbm.free_energy_given_v(train_batch).sum(), |
488 grads, | 493 #rbm.free_energy_given_v(train_batch,extra=1)[1][0].sum(), |
489 lr=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ])) | 494 #rbm.free_energy_given_v(train_batch,extra=1)[1][1].sum(), |
490 theano.printing.pydotprint(learn_fn, 'learn_fn.png') | 495 #rbm.free_energy_given_v(train_batch,extra=1)[1][2].sum(), |
496 #rbm.free_energy_given_v(train_batch,extra=1)[1][3].sum(), | |
497 theano.printing.pydotprint(function([batch_idx, s_l1_penalty], grads[0]), 'grads0.png') | |
491 | 498 |
492 print "Learning..." | 499 print "Learning..." |
493 normVF=1 | 500 normVF=1 |
501 last_epoch = -1 | |
494 for jj in xrange(n_train_iters): | 502 for jj in xrange(n_train_iters): |
495 | 503 epoch = jj*batchsize / epoch_size |
496 print_jj = ((1 and jj < 100) | 504 |
497 or (0 and jj < 100 and 0==jj%10) | 505 print_jj = epoch != last_epoch |
498 or (jj < 1000 and 0==jj%100) | 506 last_epoch = epoch |
499 or (1 and jj < 10000 and 0==jj%1000)) | 507 |
500 | 508 if epoch > 10: |
509 break | |
501 | 510 |
502 if print_jj: | 511 if print_jj: |
503 tile(imgs_fn(jj), "imgs_%06i.png"%jj) | 512 tile(imgs_fn(jj), "imgs_%06i.png"%jj) |
504 tile(sampler.positions[0].value, "sample_%06i.png"%jj) | 513 tile(sampler.positions[0].value, "sample_%06i.png"%jj) |
505 tile(rbm.U.value.T, "U_%06i.png"%jj) | 514 tile(rbm.U.value.T, "U_%06i.png"%jj) |
506 tile(rbm.W.value.T, "W_%06i.png"%jj) | 515 tile(rbm.W.value.T, "W_%06i.png"%jj) |
507 | 516 |
508 print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize), | 517 print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize) |
518 | |
509 print 'l2(U)', l2(rbm.U.value), | 519 print 'l2(U)', l2(rbm.U.value), |
510 print 'l2(W)', l2(rbm.W.value), | 520 print 'l2(W)', l2(rbm.W.value) |
521 | |
511 print 'U min max', rbm.U.value.min(), rbm.U.value.max(), | 522 print 'U min max', rbm.U.value.min(), rbm.U.value.max(), |
512 print 'W min max', rbm.W.value.min(), rbm.W.value.max(), | 523 print 'W min max', rbm.W.value.min(), rbm.W.value.max(), |
513 print 'a min max', rbm.a.value.min(), rbm.a.value.max(), | 524 print 'a min max', rbm.a.value.min(), rbm.a.value.max(), |
514 print 'b min max', rbm.b.value.min(), rbm.b.value.max(), | 525 print 'b min max', rbm.b.value.min(), rbm.b.value.max(), |
515 print 'c min max', rbm.c.value.min(), rbm.c.value.max(), | 526 print 'c min max', rbm.c.value.min(), rbm.c.value.max() |
516 | 527 |
517 print 'parts min', sampler.positions[0].value.min(), | 528 print 'parts min', sampler.positions[0].value.min(), |
518 print 'max',sampler.positions[0].value.max(), | 529 print 'max',sampler.positions[0].value.max(), |
519 print 'HMC step', sampler.stepsize, | 530 print 'HMC step', sampler.stepsize, |
520 print 'arate', sampler.avg_acceptance_rate | 531 print 'arate', sampler.avg_acceptance_rate |
524 l2_of_Ugrad = learn_fn(jj, | 535 l2_of_Ugrad = learn_fn(jj, |
525 lr/max(1, jj/(20*epoch_size/batchsize)), | 536 lr/max(1, jj/(20*epoch_size/batchsize)), |
526 effective_l1_penalty) | 537 effective_l1_penalty) |
527 | 538 |
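
The learning-rate argument passed to `learn_fn` implements a delayed 1/t decay. Written out (Python 2 integer division, as in the script):

```python
def annealed_lr(lr, jj, epoch_size, batchsize):
    # constant for the first 20 epochs, then roughly 1/t decay in 20-epoch units
    iters_per_20_epochs = 20 * epoch_size / batchsize  # integer division in Python 2
    return lr / max(1, jj / iters_per_20_epochs)
```
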
528 if print_jj: | 539 if print_jj: |
529 print 'l2(gU)', float(l2_of_Ugrad[0]), | 540 print 'l2(U_grad)', float(l2_of_Ugrad[0]), |
530 print 'FE+', float(l2_of_Ugrad[1]), | 541 print 'l2(U_inc)', float(l2_of_Ugrad[1]), |
531 print 'FE+[0]', float(l2_of_Ugrad[2]), | 542 print 'l2(W_inc)', float(l2_of_Ugrad[2]), |
532 print 'FE+[1]', float(l2_of_Ugrad[3]), | 543 #print 'FE+', float(l2_of_Ugrad[2]), |
533 print 'FE+[2]', float(l2_of_Ugrad[4]), | 544 #print 'FE+[0]', float(l2_of_Ugrad[3]), |
534 print 'FE+[3]', float(l2_of_Ugrad[5]), | 545 #print 'FE+[1]', float(l2_of_Ugrad[4]), |
546 #print 'FE+[2]', float(l2_of_Ugrad[5]), | |
547 #print 'FE+[3]', float(l2_of_Ugrad[6]) | |
535 | 548 |
536 if jj == no_l1_epochs * epoch_size/batchsize: | 549 if jj == no_l1_epochs * epoch_size/batchsize: |
537 print "Activating L1 weight decay" | 550 print "Activating L1 weight decay" |
538 effective_l1_penalty = 1e-3 | 551 effective_l1_penalty = 1e-3 |
539 | 552 |
540 if 0: | 553 # weird normalization technique... |
541 rbm.U.value = numpy_project_onto_ball(rbm.U.value.T).T | 554 # It constrains all the columns of the matrix to have the same length |
542 else: | 555 # But the matrix itself is re-scaled to have an arbitrary absolute size. |
543 # weird normalization technique... | 556 U = rbm.U.value |
544 # It constrains all the columns of the matrix to have the same length | 557 U_norms = np.sqrt((U*U).sum(axis=0)) |
545 # But the matrix itself is re-scaled to have an arbitrary absolute size. | 558 assert len(U_norms) == n_F |
546 U = rbm.U.value | 559 normVF = .95 * normVF + .05 * np.mean(U_norms) |
547 U_norms = np.sqrt((U*U).sum(axis=0)) | 560 rbm.U.value = rbm.U.value * normVF/U_norms |
548 assert len(U_norms) == n_F | |
549 normVF = .95 * normVF + .05 * np.mean(U_norms) | |
550 rbm.U.value = rbm.U.value * normVF/U_norms | |
551 | 561 |
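
Spelled out, the step above forces every column of `U` to the identical length `normVF`, while `normVF` itself drifts slowly toward the average column norm that learning produces. A standalone NumPy sketch of one such step (shapes assumed):

```python
import numpy as np

def renormalize_columns(U, normVF, decay=0.95):
    U_norms = np.sqrt((U * U).sum(axis=0))                  # one norm per column (n_F of them)
    normVF = decay * normVF + (1 - decay) * U_norms.mean()  # slow-moving target length
    return U * (normVF / U_norms), normVF                   # every column rescaled to normVF
```
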
552 | 562 |
553 # | 563 # |
554 # | 564 # |
555 # Marc'Aurelio Ranzato's code | 565 # Marc'Aurelio Ranzato's code |