Mercurial > pylearn
comparison pylearn/sandbox/scan_inputs_groups.py @ 701:113946723973
Fixed a bug in scan_inputs_groups
author | Xavier Glorot <glorotxa@iro.umontreal.ca> |
---|---|
date | Tue, 19 May 2009 19:00:34 -0400 |
parents | 33e13d8ca7d3 |
children | 0eae6d5315b5 |
comparison
equal
deleted
inserted
replaced
700:33e13d8ca7d3 | 701:113946723973 |
---|---|
344 z[0][0] = numpy.zeros((self.m.hidt.shape[1],self.m.hidt.shape[0])) | 344 z[0][0] = numpy.zeros((self.m.hidt.shape[1],self.m.hidt.shape[0])) |
345 | 345 |
346 for i in range((len(args)-4)): | 346 for i in range((len(args)-4)): |
347 if not zcalc[i]: | 347 if not zcalc[i]: |
348 shp = args[3+i].shape | 348 shp = args[3+i].shape |
349 z[i+1][0] = numpy.zeros(shp[0]) | 349 z[i+1][0] = numpy.zeros(shp) |
350 | 350 |
351 | 351 |
352 def __hash__(self): | 352 def __hash__(self): |
353 return hash(ScanDotDecGrad)^87445 | 353 return hash(ScanDotDecGrad)^87445 |
354 | 354 |
527 weights_list = Checkweights_list(weights_list) | 527 weights_list = Checkweights_list(weights_list) |
528 return Apply(self, [idx_list] + weights_list, [T.iscalar() for i in range(len(weights_list))]) | 528 return Apply(self, [idx_list] + weights_list, [T.iscalar() for i in range(len(weights_list))]) |
529 | 529 |
530 def perform(self, node, args, z): | 530 def perform(self, node, args, z): |
531 if self.encbool: | 531 if self.encbool: |
532 idx_list = args[0][args[0]>0] | 532 idx_list = args[0] |
533 dim = 1 | 533 dim = 1 |
534 else: | 534 else: |
535 idx_list = abs(args[0][args[0] != 0]) | 535 idx_list = abs(args[0]) |
536 dim = 0 | 536 dim = 0 |
537 n_hid = args[1].shape[dim] | 537 n_hid = args[1].shape[dim] |
538 | 538 |
539 if max(idx_list) >= (len(args)-1)+1 : | 539 if max(idx_list) >= (len(args)-1)+1 : |
540 raise NotImplementedError('index superior to weights list length',idx_listdec) | 540 raise NotImplementedError('index superior to weights list length',idx_listdec) |
541 for i in range(len(args)-1): | 541 for i in range(len(args)-1): |
542 if args[1+i].shape[dim] != n_hid: | 542 if args[1+i].shape[dim] != n_hid: |
543 raise NotImplementedError('different length of hidden in the encoding weights list',args[1+i].shape) | 543 raise NotImplementedError('different length of hidden in the encoding weights list',args[1+i].shape) |