changeset 700:33e13d8ca7d3

improved efficiency of scan_inputs_groups
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Tue, 19 May 2009 11:23:54 -0400
parents 6a56fcec4677
children 113946723973
files pylearn/sandbox/scan_inputs_groups.py
diffstat 1 files changed, 19 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/sandbox/scan_inputs_groups.py	Fri May 15 17:52:24 2009 -0400
+++ b/pylearn/sandbox/scan_inputs_groups.py	Tue May 19 11:23:54 2009 -0400
@@ -16,7 +16,8 @@
 #								if inputslist[i]>0 it refers to Weightslist[indexlist[i]-1]
 #	*the 0 means that the second element of the input list will not be encoded neither decoded (it is remplaced by zeros)
 #		this is not efficient, so in this case it is better to give: [1,-3] and [inputslist[0],inputslist[2]]
-#		but it allows us to deal with empty lists: give indexlist = [.0] and inputlist=[[.0]]
+#		but it allows us to deal with empty lists: give indexlist = numpy.asarray([.0])
+#		and inputlist=numpy.zeros((batchsize,1))
 #	*when an index is negative it means that the input will not be used for encoding but we will still reconstruct it
 #		(auxiliary target as output)
 #								if inputslist[i]<0 it refers to Weightslist[-indexlist[i]-1]
@@ -125,7 +126,11 @@
 		
 	
 	def grad(self, args, gz):
-		return [None, None] + ScanDotEncGrad()(args,gz)
+		gradi = ScanDotEncGrad()(args,gz)
+		if type(gradi) != list:
+			return [None, None] + [gradi]
+		else:
+			return [None, None] + gradi
 	
 	def __hash__(self):
 		return hash(ScanDotEnc)^58994
@@ -186,7 +191,7 @@
 		for i in range(len(args)-3):
 			if not zcalc[i]:
 				shp = args[2+i].shape
-				z[i][0] = numpy.zeros((shp[0],shp[1]))
+				z[i][0] = numpy.zeros(shp)
 		
 	def __hash__(self):
 		return hash(ScanDotEncGrad)^15684
@@ -255,7 +260,11 @@
 		z[0] = numpy.concatenate(z[0],1)
 		
 	def grad(self, args, gz):
-		return [None, None] + ScanDotDecGrad()(args,gz)
+		gradi = ScanDotDecGrad()(args,gz)
+		if type(gradi) != list:
+			return [None, None] + [gradi]
+		else:
+			return [None, None] + gradi
 	
 	def __hash__(self):
 		return hash(ScanDotDec)^73568
@@ -337,7 +346,7 @@
 		for i in range((len(args)-4)):
 			if not zcalc[i]:
 				shp = args[3+i].shape
-				z[i+1][0] = numpy.zeros((shp[0],shp[1]))
+				z[i+1][0] = numpy.zeros(shp[0])
 		
 		
 	def __hash__(self):
@@ -450,7 +459,11 @@
 		return hash(ScanBiasDec)^60056
 	
 	def grad(self,args,gz):
-		return [None,None] + ScanBiasDecGrad()(args,gz)
+		gradi = ScanBiasDecGrad()(args,gz)
+		if type(gradi) != list:
+			return [None, None] + [gradi]
+		else:
+			return [None, None] + gradi
 	
 	def __str__(self):
 		return "ScanBiasDec"