Mercurial > pylearn
comparison pylearn/algorithms/mcRBM.py @ 1275:f0129e37a8ef
mcRBM - changed params from lambda to method for pickling
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 08 Sep 2010 13:18:13 -0400 |
parents | 7bb5dd98e671 |
children | 1817485d586d |
comparison
equal
deleted
inserted
replaced
1274:9d5905d6d879 | 1275:f0129e37a8ef |
---|---|
411 inputs = [v], | 411 inputs = [v], |
412 outputs = list(self.expected_h_g_given_v(v)), | 412 outputs = list(self.expected_h_g_given_v(v)), |
413 params = [self.U, self.W, self.b, self.c], | 413 params = [self.U, self.W, self.b, self.c], |
414 ) | 414 ) |
415 | 415 |
416 def params(self): | |
417 """Return the elements of [U,W,a,b,c] that are shared variables | |
418 | |
419 WRITEME : a *prescriptive* definition of this method suitable for mention in the API | |
420 doc. | |
421 | |
422 """ | |
423 return list(self._params) | |
424 | |
416 @classmethod | 425 @classmethod |
417 def alloc(cls, n_I, n_K, n_J, rng = 8923402190, | 426 def alloc(cls, n_I, n_K, n_J, rng = 8923402190, |
418 U_range=0.02, | 427 U_range=0.02, |
419 W_range=0.05, | 428 W_range=0.05, |
420 a_ival=0, | 429 a_ival=0, |
438 U = sharedX(U_range * rng.randn(n_I, n_K),'U'), | 447 U = sharedX(U_range * rng.randn(n_I, n_K),'U'), |
439 W = sharedX(W_range * rng.randn(n_I, n_J),'W'), | 448 W = sharedX(W_range * rng.randn(n_I, n_J),'W'), |
440 a = sharedX(np.ones(n_I)*a_ival,'a'), | 449 a = sharedX(np.ones(n_I)*a_ival,'a'), |
441 b = sharedX(np.ones(n_K)*b_ival,'b'), | 450 b = sharedX(np.ones(n_K)*b_ival,'b'), |
442 c = sharedX(np.ones(n_J)*c_ival,'c'),) | 451 c = sharedX(np.ones(n_J)*c_ival,'c'),) |
443 | 452 rval._params = [rval.U, rval.W, rval.a, rval.b, rval.c] |
444 rval.params = lambda : [rval.U, rval.W, rval.a, rval.b, rval.c] | |
445 return rval | 453 return rval |
446 | 454 |
447 class mcRBMTrainer(object): | 455 class mcRBMTrainer(object): |
448 """Light-weight class encapsulating math for mcRBM training | 456 """Light-weight class encapsulating math for mcRBM training |
449 | 457 |
563 import pylearn.algorithms.tests.test_mcRBM | 571 import pylearn.algorithms.tests.test_mcRBM |
564 rbm,smplr = pylearn.algorithms.tests.test_mcRBM.test_reproduce_ranzato_hinton_2010( | 572 rbm,smplr = pylearn.algorithms.tests.test_mcRBM.test_reproduce_ranzato_hinton_2010( |
565 as_unittest=False, | 573 as_unittest=False, |
566 n_train_iters=10) | 574 n_train_iters=10) |
567 import cPickle | 575 import cPickle |
576 print '' | |
577 print 'Saving rbm...' | |
568 cPickle.dump(rbm, open('mcRBM.rbm.pkl', 'w'), -1) | 578 cPickle.dump(rbm, open('mcRBM.rbm.pkl', 'w'), -1) |
579 print 'Saving sampler...' | |
569 cPickle.dump(smplr, open('mcRBM.smplr.pkl', 'w'), -1) | 580 cPickle.dump(smplr, open('mcRBM.smplr.pkl', 'w'), -1) |
570 | 581 |