comparison pylearn/sandbox/scan_inputs_groups.py @ 768:bd95e8ea99d8

small change to Checks of scan_inputs_groups ops
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Mon, 08 Jun 2009 13:14:09 -0400
parents 1e97e7c7f11f
children 742972b6906a
comparison
equal deleted inserted replaced
767:1e97e7c7f11f 768:bd95e8ea99d8
102 n_hid = (args[2].shape)[1] 102 n_hid = (args[2].shape)[1]
103 if len(idx_list) != len(args[1]) : 103 if len(idx_list) != len(args[1]) :
104 raise NotImplementedError('size of index different of inputs list size',idx_list) 104 raise NotImplementedError('size of index different of inputs list size',idx_list)
105 if max(idx_list) >= (len(args)-2)+1 : 105 if max(idx_list) >= (len(args)-2)+1 :
106 raise NotImplementedError('index superior to weight list length',idx_list) 106 raise NotImplementedError('index superior to weight list length',idx_list)
107 for i in range(len(args[1])): 107 for a in args[1]:
108 if (args[1][i].shape)[0] != batchsize: 108 if (a.shape)[0] != batchsize:
109 raise NotImplementedError('different batchsize in the inputs list',args[1][i].shape) 109 raise NotImplementedError('different batchsize in the inputs list',a.shape)
110 for i in range(len(args)-2): 110 for a in args[2:]:
111 if (args[2+i].shape)[1] != n_hid: 111 if (a.shape)[1] != n_hid:
112 raise NotImplementedError('different length of hidden in the weights list',args[2+i].shape) 112 raise NotImplementedError('different length of hidden in the weights list',a.shape)
113 113
114 for i in range(len(idx_list)): 114 for i in range(len(idx_list)):
115 if idx_list[i]>0: 115 if idx_list[i]>0:
116 if hidcalc: 116 if hidcalc:
117 self.m.dotin(args[1][i],args[2+int(idx_list[i]-1)]) 117 self.m.dotin(args[1][i],args[2+int(idx_list[i]-1)])
169 n_hid = (args[2].shape)[1] 169 n_hid = (args[2].shape)[1]
170 if len(idx_list) != len(args[1]) : 170 if len(idx_list) != len(args[1]) :
171 raise NotImplementedError('size of index different of inputs list size',idx_list) 171 raise NotImplementedError('size of index different of inputs list size',idx_list)
172 if max(idx_list) >= (len(args)-3)+1 : 172 if max(idx_list) >= (len(args)-3)+1 :
173 raise NotImplementedError('index superior to weight list length',idx_list) 173 raise NotImplementedError('index superior to weight list length',idx_list)
174 for i in range(len(args[1])): 174 for a in args[1]:
175 if (args[1][i].shape)[0] != batchsize: 175 if (a.shape)[0] != batchsize:
176 raise NotImplementedError('different batchsize in the inputs list',args[1][i].shape) 176 raise NotImplementedError('different batchsize in the inputs list',a.shape)
177 for i in range(len(args)-3): 177 for a in args[2:-1]:
178 if (args[2+i].shape)[1] != n_hid: 178 if (a.shape)[1] != n_hid:
179 raise NotImplementedError('different length of hidden in the weights list',args[2+i].shape) 179 raise NotImplementedError('different length of hidden in the weights list',a.shape)
180 180
181 zcalc = [False for i in range(len(args)-3)] 181 zcalc = [False for i in range(len(args)-3)]
182 182
183 for i in range(len(idx_list)): 183 for i in range(len(idx_list)):
184 if idx_list[i]>0: 184 if idx_list[i]>0:
235 n_hid = self.m.hid.shape[1] 235 n_hid = self.m.hid.shape[1]
236 if max(idx_list) >= len(args)-3+1 : 236 if max(idx_list) >= len(args)-3+1 :
237 raise NotImplementedError('index superior to weight list length',idx_list) 237 raise NotImplementedError('index superior to weight list length',idx_list)
238 if len(idx_list) != len(args[1]) : 238 if len(idx_list) != len(args[1]) :
239 raise NotImplementedError('size of index different of inputs list size',idx_list) 239 raise NotImplementedError('size of index different of inputs list size',idx_list)
240 for i in range(len(args)-3): 240 for a in args[3:]:
241 if (args[3+i].shape)[0] != n_hid: 241 if (a.shape)[0] != n_hid:
242 raise NotImplementedError('different length of hidden in the weights list',args[3+i].shape) 242 raise NotImplementedError('different length of hidden in the weights list',a.shape)
243 243
244 zcalc = [False for i in idx_list] 244 zcalc = [False for i in idx_list]
245 z[0] = [None for i in idx_list] 245 z[0] = [None for i in idx_list]
246 246
247 for i in range(len(idx_list)): 247 for i in range(len(idx_list)):
309 n_hid = self.m.hidt.shape[0] 309 n_hid = self.m.hidt.shape[0]
310 if max(idx_list) >= len(args)-4+1 : 310 if max(idx_list) >= len(args)-4+1 :
311 raise NotImplementedError('index superior to weight list length',idx_list) 311 raise NotImplementedError('index superior to weight list length',idx_list)
312 if len(idx_list) != len(args[1]) : 312 if len(idx_list) != len(args[1]) :
313 raise NotImplementedError('size of index different of inputs list size',idx_list) 313 raise NotImplementedError('size of index different of inputs list size',idx_list)
314 for a in args[3:]: 314 for a in args[3:-1]:
315 if a.shape[0] != n_hid: 315 if a.shape[0] != n_hid:
316 raise NotImplementedError('different length of hidden in the weights list',args[3+i].shape) 316 raise NotImplementedError('different length of hidden in the weights list',a.shape)
317 317
318 zidx=numpy.zeros((len(idx_list)+1)) 318 zidx=numpy.zeros((len(idx_list)+1))
319 319
320 for i in range(len(idx_list)): 320 for i in range(len(idx_list)):
321 if idx_list[i] == 0: 321 if idx_list[i] == 0:
536 dim = 0 536 dim = 0
537 n_hid = args[1].shape[dim] 537 n_hid = args[1].shape[dim]
538 538
539 if max(idx_list) >= (len(args)-1)+1 : 539 if max(idx_list) >= (len(args)-1)+1 :
540 raise NotImplementedError('index superior to weights list length',idx_listdec) 540 raise NotImplementedError('index superior to weights list length',idx_listdec)
541 for i in range(len(args)-1): 541 for a in args[1:]:
542 if args[1+i].shape[dim] != n_hid: 542 if a.shape[dim] != n_hid:
543 raise NotImplementedError('different length of hidden in the encoding weights list',args[1+i].shape) 543 raise NotImplementedError('different length of hidden in the encoding weights list',a.shape)
544 544
545 for i in range(len(args[1:])): 545 for i in range(len(args[1:])):
546 z[i][0] = numpy.asarray((idx_list == i+1).sum(),dtype='int32') 546 z[i][0] = numpy.asarray((idx_list == i+1).sum(),dtype='int32')
547 547
548 def __hash__(self): 548 def __hash__(self):