diff pylearn/sandbox/scan_inputs_groups.py @ 768:bd95e8ea99d8

small change to Checks of scan_inputs_groups ops
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Mon, 08 Jun 2009 13:14:09 -0400
parents 1e97e7c7f11f
children 742972b6906a
line wrap: on
line diff
--- a/pylearn/sandbox/scan_inputs_groups.py	Mon Jun 08 12:14:14 2009 -0400
+++ b/pylearn/sandbox/scan_inputs_groups.py	Mon Jun 08 13:14:09 2009 -0400
@@ -104,12 +104,12 @@
             raise NotImplementedError('size of index different of inputs list size',idx_list)
         if max(idx_list) >= (len(args)-2)+1 :
             raise NotImplementedError('index superior to weight list length',idx_list)
-        for i in range(len(args[1])):
-            if (args[1][i].shape)[0] != batchsize:
-                raise NotImplementedError('different batchsize in the inputs list',args[1][i].shape)
-        for i in range(len(args)-2):
-            if (args[2+i].shape)[1] != n_hid:
-                raise NotImplementedError('different length of hidden in the weights list',args[2+i].shape)
+        for a in args[1]:
+            if (a.shape)[0] != batchsize:
+                raise NotImplementedError('different batchsize in the inputs list',a.shape)
+        for a in args[2:]:
+            if (a.shape)[1] != n_hid:
+                raise NotImplementedError('different length of hidden in the weights list',a.shape)
     
         for i in range(len(idx_list)):
             if idx_list[i]>0:
@@ -171,12 +171,12 @@
             raise NotImplementedError('size of index different of inputs list size',idx_list)
         if max(idx_list) >= (len(args)-3)+1 :
             raise NotImplementedError('index superior to weight list length',idx_list)
-        for i in range(len(args[1])):
-            if (args[1][i].shape)[0] != batchsize:
-                raise NotImplementedError('different batchsize in the inputs list',args[1][i].shape)
-        for i in range(len(args)-3):
-            if (args[2+i].shape)[1] != n_hid:
-                raise NotImplementedError('different length of hidden in the weights list',args[2+i].shape)
+        for a in args[1]:
+            if (a.shape)[0] != batchsize:
+                raise NotImplementedError('different batchsize in the inputs list',a.shape)
+        for a in args[2:-1]:
+            if (a.shape)[1] != n_hid:
+                raise NotImplementedError('different length of hidden in the weights list',a.shape)
     
         zcalc = [False for i in range(len(args)-3)]
     
@@ -237,9 +237,9 @@
             raise NotImplementedError('index superior to weight list length',idx_list)
         if len(idx_list) != len(args[1]) :
             raise NotImplementedError('size of index different of inputs list size',idx_list)
-        for i in range(len(args)-3):
-            if (args[3+i].shape)[0] != n_hid:
-                raise NotImplementedError('different length of hidden in the weights list',args[3+i].shape)
+        for a in args[3:]:
+            if (a.shape)[0] != n_hid:
+                raise NotImplementedError('different length of hidden in the weights list',a.shape)
     
         zcalc = [False for i in idx_list]
         z[0] = [None for i in idx_list]
@@ -311,9 +311,9 @@
             raise NotImplementedError('index superior to weight list length',idx_list)
         if len(idx_list) != len(args[1]) :
             raise NotImplementedError('size of index different of inputs list size',idx_list)
-        for a in args[3:]:
+        for a in args[3:-1]:
             if a.shape[0] != n_hid:
-                raise NotImplementedError('different length of hidden in the weights list',args[3+i].shape)
+                raise NotImplementedError('different length of hidden in the weights list',a.shape)
     
         zidx=numpy.zeros((len(idx_list)+1))
     
@@ -538,9 +538,9 @@
 
         if max(idx_list) >= (len(args)-1)+1 :
             raise NotImplementedError('index superior to weights list length',idx_listdec)
-        for i in range(len(args)-1):
-            if args[1+i].shape[dim] != n_hid:
-                raise NotImplementedError('different length of hidden in the encoding weights list',args[1+i].shape)
+        for a in args[1:]:
+            if a.shape[dim] != n_hid:
+                raise NotImplementedError('different length of hidden in the encoding weights list',a.shape)
     
         for i in range(len(args[1:])):
             z[i][0] = numpy.asarray((idx_list == i+1).sum(),dtype='int32')