comparison pylearn/algorithms/mcRBM.py @ 992:30b7c4defb6c

mcRBM - it works and committing it is taking forever... let's try this approach
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 24 Aug 2010 14:52:09 -0400
parents d68828c98c38
children 88107ec01ce8
comparison
equal deleted inserted replaced
991:d68828c98c38 992:30b7c4defb6c
254 """ 254 """
255 (U,W,a,b,c) = rbm 255 (U,W,a,b,c) = rbm
256 unit_v = v / (TT.sqrt(TT.mean(v**2, axis=1)+small)).dimshuffle(0,'x') # adjust row norm 256 unit_v = v / (TT.sqrt(TT.mean(v**2, axis=1)+small)).dimshuffle(0,'x') # adjust row norm
257 return b - 0.5 * dot(unit_v, U)**2 257 return b - 0.5 * dot(unit_v, U)**2
258 258
259 def free_energy_given_v(rbm, v): 259 def free_energy_terms_given_v(rbm, v):
260 """Returns theano expression for free energy of visible vector `v` in an mcRBM 260 """Returns theano expression for the terms that are added to form the free energy of
261 261 visible vector `v` in an mcRBM.
262 An mcRBM is parametrized 262
263 by `U`, `W`, `b`, `c`. 263 1. Free energy related to covariance hiddens
264 See module - level documentation for explanations of the `U`, `W`, `b` and `c` parameters. 264 2. Free energy related to mean hiddens
265 265 3. Free energy related to L2-Norm of `v`
266 266 4. Free energy related to projection of `v` onto biases `a`
267 The free energy of v is what we need for learning and hybrid Monte-carlo negative-phase
268 sampling.
269
270 """ 267 """
271 U, W, a, b, c = rbm 268 U, W, a, b, c = rbm
272 t0 = -TT.sum(TT.nnet.softplus(hidden_cov_units_preactivation_given_v(rbm, v)),axis=1) 269 t0 = -TT.sum(TT.nnet.softplus(hidden_cov_units_preactivation_given_v(rbm, v)),axis=1)
273 t1 = -TT.sum(TT.nnet.softplus(c + dot(v,W)), axis=1) 270 t1 = -TT.sum(TT.nnet.softplus(c + dot(v,W)), axis=1)
274 t2 = 0.5 * TT.sum(v**2, axis=1) 271 t2 = 0.5 * TT.sum(v**2, axis=1)
275 t3 = -TT.dot(v, a) 272 t3 = -TT.dot(v, a)
276 return t0 + t1 + t2 + t3, (t0, t1, t2, t3) 273 return [t0, t1, t2, t3]
274
275 def free_energy_given_v(rbm, v):
276 """Returns theano expression for free energy of visible vector `v` in an mcRBM
277 """
278 return sum(free_energy_terms_given_v(rbm,v))
277 279
278 def contrastive_gradient(rbm, pos_v, neg_v, U_l1_penalty=0, W_l1_penalty=0): 280 def contrastive_gradient(rbm, pos_v, neg_v, U_l1_penalty=0, W_l1_penalty=0):
279 """Return a list of gradient expressions for the rbm parameters 281 """Return a list of gradient expressions for the rbm parameters
280 282
281 :param pos_v: positive-phase sample of visible units 283 :param pos_v: positive-phase sample of visible units
361 return HMC_sampler( 363 return HMC_sampler(
362 positions = [as_shared( 364 positions = [as_shared(
363 np.random.RandomState(seed^20893).randn( 365 np.random.RandomState(seed^20893).randn(
364 n_particles, 366 n_particles,
365 self.n_visible ))], 367 self.n_visible ))],
366 energy_fn = lambda p : self.free_energy_given_v(p[0]), 368 energy_fn = lambda p : free_energy_given_v(self.params, p[0]),
367 seed=seed) 369 seed=seed)
368 370
369 def free_energy_given_v(self, v, extra=False): 371 def free_energy_given_v(self, v, extra=False):
372 assert 0
370 rval = free_energy_given_v(self.params, v) 373 rval = free_energy_given_v(self.params, v)
371 if extra: 374 if extra:
372 return rval 375 return rval
373 else: 376 else:
374 return rval[0] 377 return rval[0]
375 378
376 def contrastive_gradient(self, *args, **kwargs) 379 def contrastive_gradient(self, *args, **kwargs):
377 """Return a list of gradient expressions for self.params 380 """Return a list of gradient expressions for self.params
378 381
379 See `contrastive_gradient` for parameters. 382 :param pos_v: positive-phase sample of visible units
383 :param neg_v: negative-phase sample of visible units
380 """ 384 """
381 return contrastive_gradient(self.params, *args, **kwargs) 385 return contrastive_gradient(self.params, *args, **kwargs)
382 386
383 387
384 if __name__ == '__main__': 388 if __name__ == '__main__':
392 demodata = scipy.io.loadmat(os.path.join(pylearn.datasets.config.data_root(),'image_patches', 'mcRBM', 'training_colorpatches_16x16_demo.mat')) 396 demodata = scipy.io.loadmat(os.path.join(pylearn.datasets.config.data_root(),'image_patches', 'mcRBM', 'training_colorpatches_16x16_demo.mat'))
393 else: 397 else:
394 R,C= 16,16 # the size of image patches 398 R,C= 16,16 # the size of image patches
395 n_patches=100000 399 n_patches=100000
396 400
397 n_train_iters=30000 401 n_train_iters=5000
398 402
399 n_burnin_steps=10000 403 n_burnin_steps=10000
400 404
401 l1_penalty=1e-3 405 l1_penalty=1e-3
402 no_l1_epochs = 10 406 no_l1_epochs = 10
403 effective_l1_penalty=0.0 407 effective_l1_penalty=0.0
404 408
405 epoch_size=50000 409 epoch_size=n_patches
406 batchsize = 128 410 batchsize = 128
407 lr = 0.075 / batchsize 411 lr = 0.075 / batchsize
408 s_lr = TT.scalar() 412 s_lr = TT.scalar()
409 s_l1_penalty=TT.scalar() 413 s_l1_penalty=TT.scalar()
410 n_K=256 414 n_K=256
418 ) 422 )
419 423
420 sampler = rbm.hmc_sampler(n_particles=batchsize) 424 sampler = rbm.hmc_sampler(n_particles=batchsize)
421 425
422 def l2(X): 426 def l2(X):
423 return (X**2).sum() 427 return numpy.sqrt((X**2).sum())
424 def tile(X, fname): 428 def tile(X, fname):
425 if dataset == 'MAR': 429 if dataset == 'MAR':
426 X = np.dot(X, demodata['invpcatransf'].T) 430 X = np.dot(X, demodata['invpcatransf'].T)
427 R=16 431 R=16
428 C=16 432 C=16
448 452
449 sys.exit() 453 sys.exit()
450 454
451 batch_idx = TT.iscalar() 455 batch_idx = TT.iscalar()
452 456
453 if 0: 457 if dataset == 'MAR':
458 op = TensorFnDataset(floatX,
459 bcast=(False,),
460 fn=load_mcRBM_demo_patches,
461 single_shape=(105,))
462 train_batch = op((batch_idx * batchsize + np.arange(batchsize))%n_patches)
463 else:
454 from pylearn.dataset_ops import image_patches 464 from pylearn.dataset_ops import image_patches
455 train_batch = image_patches.image_patches( 465 train_batch = image_patches.image_patches(
456 s_idx = (batch_idx * batchsize + np.arange(batchsize)), 466 s_idx = (batch_idx * batchsize + np.arange(batchsize)),
457 dims = (n_patches,R,C), 467 dims = (n_patches,R,C),
458 center=True, 468 center=True,
459 unitvar=True, 469 unitvar=True,
460 dtype=floatX, 470 dtype=floatX,
461 rasterized=True) 471 rasterized=True)
462 else:
463 op = TensorFnDataset(floatX,
464 bcast=(False,),
465 fn=load_mcRBM_demo_patches,
466 single_shape=(105,))
467 train_batch = op((batch_idx * batchsize + np.arange(batchsize))%n_patches)
468 472
469 imgs_fn = function([batch_idx], outputs=train_batch) 473 imgs_fn = function([batch_idx], outputs=train_batch)
470 474
471 grads = rbm.contrastive_gradient( 475 grads = rbm.contrastive_gradient(
472 pos_v=train_batch, 476 pos_v=train_batch,
473 neg_v=sampler.positions[0], 477 neg_v=sampler.positions[0],
474 U_l1_penalty=s_l1_penalty, 478 U_l1_penalty=s_l1_penalty,
475 W_l1_penalty=s_l1_penalty) 479 W_l1_penalty=s_l1_penalty)
480 sgd_ups = sgd_updates(
481 rbm.params,
482 grads,
483 lr=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ])
476 484
477 learn_fn = function([batch_idx, s_lr, s_l1_penalty], 485 learn_fn = function([batch_idx, s_lr, s_l1_penalty],
478 outputs=[ 486 outputs=[
479 grads[0].norm(2), 487 grads[0].norm(2),
480 rbm.free_energy_given_v(train_batch).sum(), 488 (sgd_ups[0][1] - sgd_ups[0][0]).norm(2),
481 rbm.free_energy_given_v(train_batch,extra=1)[1][0].sum(), 489 (sgd_ups[1][1] - sgd_ups[1][0]).norm(2),
482 rbm.free_energy_given_v(train_batch,extra=1)[1][1].sum(),
483 rbm.free_energy_given_v(train_batch,extra=1)[1][2].sum(),
484 rbm.free_energy_given_v(train_batch,extra=1)[1][3].sum(),
485 ], 490 ],
486 updates = sgd_updates( 491 updates = sgd_ups)
487 rbm.params, 492 #rbm.free_energy_given_v(train_batch).sum(),
488 grads, 493 #rbm.free_energy_given_v(train_batch,extra=1)[1][0].sum(),
489 lr=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ])) 494 #rbm.free_energy_given_v(train_batch,extra=1)[1][1].sum(),
490 theano.printing.pydotprint(learn_fn, 'learn_fn.png') 495 #rbm.free_energy_given_v(train_batch,extra=1)[1][2].sum(),
496 #rbm.free_energy_given_v(train_batch,extra=1)[1][3].sum(),
497 theano.printing.pydotprint(function([batch_idx, s_l1_penalty], grads[0]), 'grads0.png')
491 498
492 print "Learning..." 499 print "Learning..."
493 normVF=1 500 normVF=1
501 last_epoch = -1
494 for jj in xrange(n_train_iters): 502 for jj in xrange(n_train_iters):
495 503 epoch = jj*batchsize / epoch_size
496 print_jj = ((1 and jj < 100) 504
497 or (0 and jj < 100 and 0==jj%10) 505 print_jj = epoch != last_epoch
498 or (jj < 1000 and 0==jj%100) 506 last_epoch = epoch
499 or (1 and jj < 10000 and 0==jj%1000)) 507
500 508 if epoch > 10:
509 break
501 510
502 if print_jj: 511 if print_jj:
503 tile(imgs_fn(jj), "imgs_%06i.png"%jj) 512 tile(imgs_fn(jj), "imgs_%06i.png"%jj)
504 tile(sampler.positions[0].value, "sample_%06i.png"%jj) 513 tile(sampler.positions[0].value, "sample_%06i.png"%jj)
505 tile(rbm.U.value.T, "U_%06i.png"%jj) 514 tile(rbm.U.value.T, "U_%06i.png"%jj)
506 tile(rbm.W.value.T, "W_%06i.png"%jj) 515 tile(rbm.W.value.T, "W_%06i.png"%jj)
507 516
508 print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize), 517 print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize)
518
509 print 'l2(U)', l2(rbm.U.value), 519 print 'l2(U)', l2(rbm.U.value),
510 print 'l2(W)', l2(rbm.W.value), 520 print 'l2(W)', l2(rbm.W.value)
521
511 print 'U min max', rbm.U.value.min(), rbm.U.value.max(), 522 print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
512 print 'W min max', rbm.W.value.min(), rbm.W.value.max(), 523 print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
513 print 'a min max', rbm.a.value.min(), rbm.a.value.max(), 524 print 'a min max', rbm.a.value.min(), rbm.a.value.max(),
514 print 'b min max', rbm.b.value.min(), rbm.b.value.max(), 525 print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
515 print 'c min max', rbm.c.value.min(), rbm.c.value.max(), 526 print 'c min max', rbm.c.value.min(), rbm.c.value.max()
516 527
517 print 'parts min', sampler.positions[0].value.min(), 528 print 'parts min', sampler.positions[0].value.min(),
518 print 'max',sampler.positions[0].value.max(), 529 print 'max',sampler.positions[0].value.max(),
519 print 'HMC step', sampler.stepsize, 530 print 'HMC step', sampler.stepsize,
520 print 'arate', sampler.avg_acceptance_rate 531 print 'arate', sampler.avg_acceptance_rate
524 l2_of_Ugrad = learn_fn(jj, 535 l2_of_Ugrad = learn_fn(jj,
525 lr/max(1, jj/(20*epoch_size/batchsize)), 536 lr/max(1, jj/(20*epoch_size/batchsize)),
526 effective_l1_penalty) 537 effective_l1_penalty)
527 538
528 if print_jj: 539 if print_jj:
529 print 'l2(gU)', float(l2_of_Ugrad[0]), 540 print 'l2(U_grad)', float(l2_of_Ugrad[0]),
530 print 'FE+', float(l2_of_Ugrad[1]), 541 print 'l2(U_inc)', float(l2_of_Ugrad[1]),
531 print 'FE+[0]', float(l2_of_Ugrad[2]), 542 print 'l2(W_inc)', float(l2_of_Ugrad[2]),
532 print 'FE+[1]', float(l2_of_Ugrad[3]), 543 #print 'FE+', float(l2_of_Ugrad[2]),
533 print 'FE+[2]', float(l2_of_Ugrad[4]), 544 #print 'FE+[0]', float(l2_of_Ugrad[3]),
534 print 'FE+[3]', float(l2_of_Ugrad[5]), 545 #print 'FE+[1]', float(l2_of_Ugrad[4]),
546 #print 'FE+[2]', float(l2_of_Ugrad[5]),
547 #print 'FE+[3]', float(l2_of_Ugrad[6])
535 548
536 if jj == no_l1_epochs * epoch_size/batchsize: 549 if jj == no_l1_epochs * epoch_size/batchsize:
537 print "Activating L1 weight decay" 550 print "Activating L1 weight decay"
538 effective_l1_penalty = 1e-3 551 effective_l1_penalty = 1e-3
539 552
540 if 0: 553 # weird normalization technique...
541 rbm.U.value = numpy_project_onto_ball(rbm.U.value.T).T 554 # It constrains all the columns of the matrix to have the same length
542 else: 555 # But the matrix itself is re-scaled to have an arbitrary absolute size.
543 # weird normalization technique... 556 U = rbm.U.value
544 # It constrains all the columns of the matrix to have the same length 557 U_norms = np.sqrt((U*U).sum(axis=0))
545 # But the matrix itself is re-scaled to have an arbitrary absolute size. 558 assert len(U_norms) == n_F
546 U = rbm.U.value 559 normVF = .95 * normVF + .05 * np.mean(U_norms)
547 U_norms = np.sqrt((U*U).sum(axis=0)) 560 rbm.U.value = rbm.U.value * normVF/U_norms
548 assert len(U_norms) == n_F
549 normVF = .95 * normVF + .05 * np.mean(U_norms)
550 rbm.U.value = rbm.U.value * normVF/U_norms
551 561
552 562
553 # 563 #
554 # 564 #
555 # Marc'Aurelio Ranzato's code 565 # Marc'Aurelio Ranzato's code