comparison pylearn/sandbox/scan_inputs_groups.py @ 701:113946723973

fixed bug of scan_inputs_groups
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Tue, 19 May 2009 19:00:34 -0400
parents 33e13d8ca7d3
children 0eae6d5315b5
comparison
equal deleted inserted replaced
700:33e13d8ca7d3 701:113946723973
344 z[0][0] = numpy.zeros((self.m.hidt.shape[1],self.m.hidt.shape[0])) 344 z[0][0] = numpy.zeros((self.m.hidt.shape[1],self.m.hidt.shape[0]))
345 345
346 for i in range((len(args)-4)): 346 for i in range((len(args)-4)):
347 if not zcalc[i]: 347 if not zcalc[i]:
348 shp = args[3+i].shape 348 shp = args[3+i].shape
349 z[i+1][0] = numpy.zeros(shp[0]) 349 z[i+1][0] = numpy.zeros(shp)
350 350
351 351
352 def __hash__(self): 352 def __hash__(self):
353 return hash(ScanDotDecGrad)^87445 353 return hash(ScanDotDecGrad)^87445
354 354
527 weights_list = Checkweights_list(weights_list) 527 weights_list = Checkweights_list(weights_list)
528 return Apply(self, [idx_list] + weights_list, [T.iscalar() for i in range(len(weights_list))]) 528 return Apply(self, [idx_list] + weights_list, [T.iscalar() for i in range(len(weights_list))])
529 529
530 def perform(self, node, args, z): 530 def perform(self, node, args, z):
531 if self.encbool: 531 if self.encbool:
532 idx_list = args[0][args[0]>0] 532 idx_list = args[0]
533 dim = 1 533 dim = 1
534 else: 534 else:
535 idx_list = abs(args[0][args[0] != 0]) 535 idx_list = abs(args[0])
536 dim = 0 536 dim = 0
537 n_hid = args[1].shape[dim] 537 n_hid = args[1].shape[dim]
538 538
539 if max(idx_list) >= (len(args)-1)+1 : 539 if max(idx_list) >= (len(args)-1)+1 :
540 raise NotImplementedError('index superior to weights list length',idx_listdec) 540 raise NotImplementedError('index superior to weights list length',idx_listdec)
541 for i in range(len(args)-1): 541 for i in range(len(args)-1):
542 if args[1+i].shape[dim] != n_hid: 542 if args[1+i].shape[dim] != n_hid:
543 raise NotImplementedError('different length of hidden in the encoding weights list',args[1+i].shape) 543 raise NotImplementedError('different length of hidden in the encoding weights list',args[1+i].shape)