comparison pylearn/algorithms/mcRBM.py @ 1275:f0129e37a8ef

mcRBM - changed params from lambda to method for pickling
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 08 Sep 2010 13:18:13 -0400
parents 7bb5dd98e671
children 1817485d586d
comparison
equal deleted inserted replaced
1274:9d5905d6d879 1275:f0129e37a8ef
411 inputs = [v], 411 inputs = [v],
412 outputs = list(self.expected_h_g_given_v(v)), 412 outputs = list(self.expected_h_g_given_v(v)),
413 params = [self.U, self.W, self.b, self.c], 413 params = [self.U, self.W, self.b, self.c],
414 ) 414 )
415 415
416 def params(self):
417 """Return the elements of [U,W,a,b,c] that are shared variables
418
419 WRITEME : a *prescriptive* definition of this method suitable for mention in the API
420 doc.
421
422 """
423 return list(self._params)
424
416 @classmethod 425 @classmethod
417 def alloc(cls, n_I, n_K, n_J, rng = 8923402190, 426 def alloc(cls, n_I, n_K, n_J, rng = 8923402190,
418 U_range=0.02, 427 U_range=0.02,
419 W_range=0.05, 428 W_range=0.05,
420 a_ival=0, 429 a_ival=0,
438 U = sharedX(U_range * rng.randn(n_I, n_K),'U'), 447 U = sharedX(U_range * rng.randn(n_I, n_K),'U'),
439 W = sharedX(W_range * rng.randn(n_I, n_J),'W'), 448 W = sharedX(W_range * rng.randn(n_I, n_J),'W'),
440 a = sharedX(np.ones(n_I)*a_ival,'a'), 449 a = sharedX(np.ones(n_I)*a_ival,'a'),
441 b = sharedX(np.ones(n_K)*b_ival,'b'), 450 b = sharedX(np.ones(n_K)*b_ival,'b'),
442 c = sharedX(np.ones(n_J)*c_ival,'c'),) 451 c = sharedX(np.ones(n_J)*c_ival,'c'),)
443 452 rval._params = [rval.U, rval.W, rval.a, rval.b, rval.c]
444 rval.params = lambda : [rval.U, rval.W, rval.a, rval.b, rval.c]
445 return rval 453 return rval
446 454
447 class mcRBMTrainer(object): 455 class mcRBMTrainer(object):
448 """Light-weight class encapsulating math for mcRBM training 456 """Light-weight class encapsulating math for mcRBM training
449 457
563 import pylearn.algorithms.tests.test_mcRBM 571 import pylearn.algorithms.tests.test_mcRBM
564 rbm,smplr = pylearn.algorithms.tests.test_mcRBM.test_reproduce_ranzato_hinton_2010( 572 rbm,smplr = pylearn.algorithms.tests.test_mcRBM.test_reproduce_ranzato_hinton_2010(
565 as_unittest=False, 573 as_unittest=False,
566 n_train_iters=10) 574 n_train_iters=10)
567 import cPickle 575 import cPickle
576 print ''
577 print 'Saving rbm...'
568 cPickle.dump(rbm, open('mcRBM.rbm.pkl', 'w'), -1) 578 cPickle.dump(rbm, open('mcRBM.rbm.pkl', 'w'), -1)
579 print 'Saving sampler...'
569 cPickle.dump(smplr, open('mcRBM.smplr.pkl', 'w'), -1) 580 cPickle.dump(smplr, open('mcRBM.smplr.pkl', 'w'), -1)
570 581