changeset 701:113946723973

fixed bug of scan_inputs_groups
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Tue, 19 May 2009 19:00:34 -0400
parents 33e13d8ca7d3
children f76079ba8d9a
files pylearn/sandbox/scan_inputs_groups.py
diffstat 1 files changed, 4 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/sandbox/scan_inputs_groups.py	Tue May 19 11:23:54 2009 -0400
+++ b/pylearn/sandbox/scan_inputs_groups.py	Tue May 19 19:00:34 2009 -0400
@@ -346,7 +346,7 @@
 		for i in range((len(args)-4)):
 			if not zcalc[i]:
 				shp = args[3+i].shape
-				z[i+1][0] = numpy.zeros(shp[0])
+				z[i+1][0] = numpy.zeros(shp)
 		
 		
 	def __hash__(self):
@@ -529,13 +529,13 @@
 	
 	def perform(self, node, args, z):
 		if self.encbool:
-			idx_list = args[0][args[0]>0]
+			idx_list = args[0]
 			dim = 1
 		else:
-			idx_list = abs(args[0][args[0] != 0])
+			idx_list = abs(args[0])
 			dim = 0
 		n_hid = args[1].shape[dim]
-		
+
 		if max(idx_list) >= (len(args)-1)+1 :
 			raise NotImplementedError('index superior to weights list length',idx_listdec)
 		for i in range(len(args)-1):