annotate pylearn/sandbox/scan_inputs_groups.py @ 768:bd95e8ea99d8

small change to Checks of scan_inputs_groups ops
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Mon, 08 Jun 2009 13:14:09 -0400
parents 1e97e7c7f11f
children 742972b6906a
rev   line source
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
1 import numpy
2 import theano
3 from theano import tensor as T
4 from theano.gof import Op
5 from theano.gof import Apply
6 from theano import scalar as scal
7
8 # These Ops allow us to deal efficiently with static groups of possibly missing inputs in the dense DAA framework
9 # (for example with multimodal data where an entire modality is sometimes missing).
10 # The inputs will be represented with an index list and a theano.generic variable (which will be a list of matrices
711
0eae6d5315b5 Fixed minor typo in comment
Olivier Delalleau <delallea@iro>
parents: 701
diff changeset
11 # (numpy array), each element will correspond to an available modality and the index list will indicate the weights
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
12 # associated with it).
13 # Example of index list: [1, 0, -3]
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
14 # *the 1 says that the first element of the input list will refer to the first element of the weights_list
15 #   (auxiliary target as input)
16 #   if indexlist[i]>0 it refers to Weightslist[indexlist[i]-1]
17 # *the 0 means that the second element of the input list will be neither encoded nor decoded (it is replaced by zeros)
18 #   this is not efficient, so in this case it is better to give: [1,-3] and [inputslist[0],inputslist[2]]
19 #   but it allows us to deal with empty lists: give indexlist = numpy.asarray([.0])
20 #   and inputlist=numpy.zeros((batchsize,1))
21 # *when an index is negative it means that the input will not be used for encoding but we will still reconstruct it
22 #   (auxiliary target as output)
23 #   if indexlist[i]<0 it refers to Weightslist[-indexlist[i]-1]
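The index-list convention described in the comments above can be sketched in plain Python. This is only an illustration and not part of the module: the helper name `decode_index_list` and the role labels are hypothetical, and the returned weight positions are zero-based offsets into the weights list.

```python
def decode_index_list(index_list):
    """Map each entry of an index list to a (role, weight_position) pair.

    Hypothetical helper illustrating the convention from the comments:
      idx > 0  -> the input is used for encoding, weights at position idx - 1
      idx == 0 -> the input is ignored (neither encoded nor decoded)
      idx < 0  -> the input is only reconstructed, weights at position -idx - 1
    """
    decoded = []
    for idx in index_list:
        idx = int(idx)
        if idx > 0:
            decoded.append(("encoded", idx - 1))
        elif idx < 0:
            decoded.append(("reconstructed_only", -idx - 1))
        else:
            decoded.append(("ignored", None))
    return decoded


# The example index list from the comments.
roles = decode_index_list([1, 0, -3])
```

With the example list `[1, 0, -3]`, the first input uses the first weight matrix, the second input is skipped, and the third input is only a reconstruction target tied to the third weight matrix.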
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
24 #
25 # An entire batch should have the same available inputs configuration.
26 #
27 # Dense DAA Example:----------------------------------------------------------------------------
28 #
29 #from theano.tensor.nnet import sigmoid
30 #
31 #nb_modality = 4
32 #wenc = [T.dmatrix('wenc%s'%i) for i in range(nb_modality)]
33 #wdec = [T.dmatrix('wdec%s'%i) for i in range(nb_modality)]
34 #benc = T.dvector('benc')
35 #bdec = [T.dvector('bdec%s'%i) for i in range(nb_modality)]
36 #vectin = T.ivector('vectin')
37 #inputpart = theano.generic('inputpart')
38 #noise_bit = T.dscalar('noise_bit')
39 #noise_group = T.dscalar('noise_group')
40 #
41 #[vectin2,inputpart2] = scannoise(vectin,inputpart,noise_bit,noise_group)
42 #hid = scandotenc(vectin2, inputpart2, wenc)
43 #acthid = sigmoid(hid + benc)
44 #dec = sigmoid(scanbiasdec(vectin2,inputpart2,bdec) + scandotdec(vectin2, inputpart2,acthid,wdec))
45 #cost = T.sum(T.sum(T.sqr( scaninput(vectin,inputpart) - dec ),1),0)
46
47 # Checking inputs in make_node methods----------------------
48 def Checkidx_list(idx_list):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
49     idx_list = T.as_tensor_variable(idx_list)
50     nidx = idx_list.type.ndim
51     if nidx != 1: raise TypeError('not vector', idx_list)
52     return idx_list
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
53
54 def Checkhidd(hidd):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
55     hidd = T.as_tensor_variable(hidd)
56     nhidd = hidd.type.ndim
57     if nhidd not in (1,2): raise TypeError('not matrix or vector', hidd)
58     return hidd
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
59
60 def Checkweights_list(weights_list):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
61     weights_list = map(T.as_tensor_variable, weights_list)
62     for i in range(len(weights_list)):
63         nweights = weights_list[i].type.ndim
64         if nweights not in (1,2): raise TypeError('not matrix or vector', weights_list[i])
65     return weights_list
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
66
67 def Checkbias_list(bias_list):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
68     bias_list = map(T.as_tensor_variable, bias_list)
69     for i in range(len(bias_list)):
70         nbias = bias_list[i].type.ndim
71         if nbias != 1: raise TypeError('not vector', bias_list[i])
72     return bias_list
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
73
74 # Encoding scan dot product------------------------------------
75 class ScanDotEnc(Op):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
76     """This Op takes an index list (as tensor.ivector), a list of matrices representing
77     the available inputs (as theano.generic), and all the encoding weights of the model (as tensor.dmatrix). It will select the
78     weights corresponding to the inputs (according to the index list) and compute only the necessary dot products"""
79     def __init__(self):
80         #Create Theano methods to do the dot products with blas or at least in C.
81         self.M=theano.Module()
82         inputs = T.dmatrix('input')
83         weights = T.dmatrix('weights')
84         self.M.hid = T.dmatrix('hid')
85         self.M.resultin = self.M.hid + T.dot(inputs,weights)
86         result = T.dot(inputs,weights)
87
88         self.M.dotin = theano.Method([inputs,weights],None,{self.M.hid : self.M.resultin})
89         self.M.dot = theano.Method([inputs,weights],result)
90         self.m = self.M.make()
91
92     def make_node(self, idx_list, inputs_list, weights_list):
93         idx_list = Checkidx_list(idx_list)
94         weights_list = Checkweights_list(weights_list)
95         return Apply(self, [idx_list] + [inputs_list] + weights_list, [T.dmatrix()])
96
97     def perform(self, node, args, (hid,)):
98         idx_list = args[0]
99         hidcalc = False
100
101         batchsize = (args[1][0].shape)[0]
102         n_hid = (args[2].shape)[1]
103         if len(idx_list) != len(args[1]) :
104             raise NotImplementedError('size of index different of inputs list size',idx_list)
105         if max(idx_list) >= (len(args)-2)+1 :
106             raise NotImplementedError('index superior to weight list length',idx_list)
768
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
107         for a in args[1]:
108             if (a.shape)[0] != batchsize:
109                 raise NotImplementedError('different batchsize in the inputs list',a.shape)
110         for a in args[2:]:
111             if (a.shape)[1] != n_hid:
112                 raise NotImplementedError('different length of hidden in the weights list',a.shape)
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
113
114         for i in range(len(idx_list)):
115             if idx_list[i]>0:
116                 if hidcalc:
117                     self.m.dotin(args[1][i],args[2+int(idx_list[i]-1)])
118                 else:
119                     self.m.hid = self.m.dot(args[1][i],args[2+int(idx_list[i]-1)])
120                     hidcalc = True
121
122         if not hidcalc:
123             hid[0] = numpy.zeros([batchsize,n_hid])
124         else:
125             hid[0] = self.m.hid
126
127
128     def grad(self, args, gz):
129         gradi = ScanDotEncGrad()(args,gz)
130         if type(gradi) != list:
131             return [None, None] + [gradi]
132         else:
133             return [None, None] + gradi
134
135     def __hash__(self):
136         return hash(ScanDotEnc)^58994
137
138     def __str__(self):
139         return "ScanDotEnc"
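The accumulation performed by `ScanDotEnc.perform` can be sketched without Theano. The following is a minimal illustration under stated assumptions, not the module's implementation: `matmul` and `encode` are hypothetical names, plain lists of rows stand in for the numpy matrices, and the compiled `theano.Module` methods are replaced by a pure-Python dot product.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows (stand-in for a BLAS dot)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def encode(index_list, inputs_list, weights_list):
    """Sum inputs_list[i] . weights_list[idx - 1] over the entries with idx > 0.

    Hypothetical sketch: entries that are zero or negative contribute nothing,
    so the result is all zeros when no input is available for encoding.
    """
    if len(index_list) != len(inputs_list):
        raise ValueError("index list and inputs list differ in length")
    batchsize = len(inputs_list[0])
    n_hid = len(weights_list[0][0])
    hid = [[0.0] * n_hid for _ in range(batchsize)]
    for idx, block in zip(index_list, inputs_list):
        if idx > 0:
            product = matmul(block, weights_list[idx - 1])
            hid = [[h + p for h, p in zip(hrow, prow)]
                   for hrow, prow in zip(hid, product)]
    return hid


# Two modalities, batch of one example; the second modality is a target only.
w0 = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]   # 2 inputs -> 3 hidden units
w1 = [[2.0, 2.0, 2.0]]                    # 1 input  -> 3 hidden units
hid = encode([1, -2], [[[1.0, 2.0]], [[5.0]]], [w0, w1])
empty = encode([0], [[[0.0]]], [w0, w1])
```

Here only the first modality is multiplied by its weights, which mirrors the Op's goal of computing only the necessary dot products; the `[0]` index list shows the empty-input case returning zeros.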
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
140
141 scandotenc=ScanDotEnc()
142
143 class ScanDotEncGrad(Op):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
144 """This Op computes the gradient wrt the weights for ScanDotEnc"""
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
145 def __init__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
146 #Create Theano methods to do the dot products with blas or at least in C.
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
147 self.M=theano.Module()
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
148 input1 = T.dmatrix('input1')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
149 self.M.g_out = T.dmatrix('g_out')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
150 result = T.dmatrix('result')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
151 input2=T.transpose(input1)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
152 self.M.resultin = result + T.dot(input2,self.M.g_out)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
153 self.M.result = T.dot(input2,self.M.g_out)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
154
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
155 self.M.dotin = theano.Method([input1,result],self.M.resultin)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
156 self.M.dot = theano.Method([input1],self.M.result)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
157 self.m = self.M.make()
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
158
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
159 def make_node(self, args, g_out):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
160 idx_list = Checkidx_list(args[0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
161 weights_list = Checkweights_list(args[2:])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
162 return Apply(self, args + g_out, [T.dmatrix() for i in xrange(2,len(args))])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
163
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
164 def perform(self, node, args, z):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
165 idx_list = args[0]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
166 self.m.g_out = args[-1]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
167
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
168 batchsize = (args[1][0].shape)[0]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
169 n_hid = (args[2].shape)[1]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
170 if len(idx_list) != len(args[1]) :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
171 raise NotImplementedError('index list and inputs list have different sizes',idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
172 if max(idx_list) >= (len(args)-3)+1 :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
173 raise NotImplementedError('index exceeds weight list length',idx_list)
768
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
174 for a in args[1]:
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
175 if (a.shape)[0] != batchsize:
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
176 raise NotImplementedError('different batchsize in the inputs list',a.shape)
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
177 for a in args[2:-1]:
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
178 if (a.shape)[1] != n_hid:
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
179 raise NotImplementedError('different length of hidden in the weights list',a.shape)
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
180
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
181 zcalc = [False for i in range(len(args)-3)]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
182
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
183 for i in range(len(idx_list)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
184 if idx_list[i]>0:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
185 if zcalc[int(idx_list[i]-1)]:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
186 z[int(idx_list[i]-1)][0] = self.m.dotin(args[1][i],z[int(idx_list[i]-1)][0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
187 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
188 z[int(idx_list[i]-1)][0] = self.m.dot(args[1][i])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
189 zcalc[int(idx_list[i]-1)] = True
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
190
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
191 for i in range(len(args)-3):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
192 if not zcalc[i]:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
193 shp = args[2+i].shape
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
194 z[i][0] = numpy.zeros(shp)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
195
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
196 def __hash__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
197 return hash(ScanDotEncGrad)^15684
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
198
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
199 def __str__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
200 return "ScanDotEncGrad"
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
201
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
202 # Decoding scan dot product------------------------------------
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
203 class ScanDotDec(Op):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
204 """This Op takes an index list (as tensor.ivector), a list of matrices representing
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
205 the available inputs (as theano.generic), the hidden layer of the DAA (theano.dmatrix)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
206 and all the decoding weights (tensor.dmatrix) of the model. It will select the
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
207 weights corresponding to the available inputs (according to index list) and compute
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
208 only the necessary dot products. The outputs will be concatenated and will represent
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
209 the reconstruction of the different modalities in the same order as the index list"""
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
210 def __init__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
211 # Create Theano methods to do the dot products with BLAS, or at least in C.
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
212 self.M=theano.Module()
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
213 weights = T.dmatrix('weights')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
214 self.M.hid = T.dmatrix('hid')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
215 oldval = T.dmatrix('oldval')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
216 resultin = oldval + T.dot(self.M.hid,weights)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
217 result = T.dot(self.M.hid,weights)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
218
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
219 self.M.dotin = theano.Method([weights,oldval],resultin)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
220 self.M.dot = theano.Method([weights],result)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
221 self.m = self.M.make()
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
222
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
223 def make_node(self, idx_list, input_list, hidd, weights_list):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
224 idx_list = Checkidx_list(idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
225 hidd = Checkhidd(hidd)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
226 weights_list = Checkweights_list(weights_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
227 return Apply(self, [idx_list] + [input_list] +[hidd] + weights_list,[T.dmatrix()])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
228
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
229 def perform(self, node, args, (z,)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
230
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
231 idx_list = abs(args[0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
232 self.m.hid = args[2]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
233
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
234 batchsize = (self.m.hid.shape)[0]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
235 n_hid = self.m.hid.shape[1]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
236 if max(idx_list) >= len(args)-3+1 :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
237 raise NotImplementedError('index exceeds weight list length',idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
238 if len(idx_list) != len(args[1]) :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
239 raise NotImplementedError('index list and inputs list have different sizes',idx_list)
768
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
240 for a in args[3:]:
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
241 if (a.shape)[0] != n_hid:
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
242 raise NotImplementedError('different length of hidden in the weights list',a.shape)
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
243
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
244 zcalc = [False for i in idx_list]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
245 z[0] = [None for i in idx_list]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
246
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
247 for i in range(len(idx_list)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
248 if idx_list[i]>0:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
249 if zcalc[i]:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
250 z[0][i] = self.m.dotin(args[3+int(idx_list[i]-1)],z[0][i])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
251 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
252 z[0][i] = self.m.dot(args[3+int(idx_list[i]-1)])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
253 zcalc[i] = True
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
254
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
255 for i in range(len(idx_list)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
256 if not zcalc[i]:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
257 shp = args[1][i].shape # missing group (index 0): take the width of the input at the same position
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
258 z[0][i] = numpy.zeros((batchsize,shp[1]))
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
259
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
260 z[0] = numpy.concatenate(z[0],1)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
261
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
262 def grad(self, args, gz):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
263 gradi = ScanDotDecGrad()(args,gz)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
264 if type(gradi) != list:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
265 return [None, None] + [gradi]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
266 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
267 return [None, None] + gradi
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
268
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
269 def __hash__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
270 return hash(ScanDotDec)^73568
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
271
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
272 def __str__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
273 return "ScanDotDec"
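The forward pass that `ScanDotDec.perform` implements can be sketched in plain NumPy, without the Theano `Module` machinery. This is only an illustration under stated assumptions, not the author's implementation: the function name `decode_groups` and its argument layout are hypothetical, and the width of a missing group is taken from the input at the same position. The index convention follows the code above: 0 marks a missing group and k > 0 selects the k-th weight matrix.

```python
import numpy as np

def decode_groups(idx_list, input_list, hid, weights_list):
    """Hypothetical NumPy sketch of the decoding pass (not the Theano Op)."""
    idx = np.abs(np.asarray(idx_list))  # the Op also takes abs() of the index list
    batchsize = hid.shape[0]
    blocks = []
    for i, k in enumerate(idx):
        if k > 0:
            # Available group: reconstruct it with the k-th decoding matrix (1-based).
            blocks.append(np.dot(hid, weights_list[int(k) - 1]))
        else:
            # Missing group: emit zeros with the width of the input at position i.
            blocks.append(np.zeros((batchsize, input_list[i].shape[1])))
    # Concatenate the reconstructions column-wise, in index-list order.
    return np.concatenate(blocks, axis=1)

# Example data (made up for illustration): two groups, the second one missing.
hid = np.ones((2, 3))
weights = [np.ones((3, 4)), 2.0 * np.ones((3, 5))]
inputs = [np.zeros((2, 4)), np.zeros((2, 5))]
out = decode_groups([1, 0], inputs, hid, weights)
print(out.shape)
```

Only the first group triggers a dot product here; the missing group contributes a block of zeros, so the concatenated output keeps a fixed column layout.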
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
274
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
275 scandotdec=ScanDotDec()
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
276
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
277 class ScanDotDecGrad(Op):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
278 """This Op computes the gradients wrt the hidden layer and the decoding weights for ScanDotDec"""
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
279 def __init__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
280 self.M=theano.Module()
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
281 gout = T.dmatrix('gout')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
282 self.M.hidt = T.dmatrix('hid')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
283 oldval = T.dmatrix('oldval')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
284 resultin1 = oldval + T.dot(self.M.hidt,gout)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
285 result1 = T.dot(self.M.hidt,gout)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
286 weights = T.dmatrix('weights')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
287 weights2 = T.transpose(weights)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
288 resultin2 = oldval + T.dot(gout,weights2)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
289 result2 = T.dot(gout,weights2)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
290
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
291 self.M.dotin1 = theano.Method([gout,oldval],resultin1)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
292 self.M.dot1 = theano.Method([gout],result1)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
293 self.M.dotin2 = theano.Method([gout,weights,oldval],resultin2)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
294 self.M.dot2 = theano.Method([gout,weights],result2)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
295 self.m = self.M.make()
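The two pairs of Theano methods built in this constructor correspond to the standard gradients of a matrix product. For `out = dot(hid, weights)` with upstream gradient `gout`, the gradient wrt the weights is `dot(hid.T, gout)` and the gradient wrt the hidden layer is `dot(gout, weights.T)`. The following NumPy sketch checks the first identity numerically; the shapes, the random data and the finite-difference check are assumptions made for illustration and are not part of the original file.

```python
import numpy as np

rng = np.random.RandomState(0)
hid = rng.randn(2, 3)       # (batchsize, n_hid)
weights = rng.randn(3, 4)   # (n_hid, group width)
gout = rng.randn(2, 4)      # gradient of the cost wrt dot(hid, weights)

grad_weights = np.dot(hid.T, gout)   # same shape as weights
grad_hid = np.dot(gout, weights.T)   # same shape as hid

# Finite-difference check on one weight entry, for the linear cost
# sum(gout * dot(hid, weights)), whose gradient wrt weights is dot(hid.T, gout).
eps = 1e-6
bumped = weights.copy()
bumped[1, 2] += eps
numeric = (np.sum(gout * np.dot(hid, bumped))
           - np.sum(gout * np.dot(hid, weights))) / eps
print(abs(numeric - grad_weights[1, 2]) < 1e-4)
```

This is why the Op stores the transposed hidden layer (`hidt`) and uses the transposed weights (`weights2`) in its second pair of methods.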
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
296
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
297
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
298 def make_node(self, args, g_out):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
299 idx_list = Checkidx_list(args[0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
300 hidd = Checkhidd(args[2])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
301 weights_list = Checkweights_list(args[3:])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
302 return Apply(self, args + g_out, [T.dmatrix() for i in xrange(2,len(args))])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
303
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
304 def perform(self, node, args, z):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
305 idx_list = abs(args[0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
306 self.m.hidt = args[2].T
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
307
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
308 batchsize = (self.m.hidt.shape)[1]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
309 n_hid = self.m.hidt.shape[0]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
310 if max(idx_list) >= len(args)-4+1 :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
311 raise NotImplementedError('index exceeds weight list length',idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
312 if len(idx_list) != len(args[1]) :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
313 raise NotImplementedError('index list and inputs list have different sizes',idx_list)
768
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
314 for a in args[3:-1]:
767
1e97e7c7f11f very small opt.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents: 745
diff changeset
315 if a.shape[0] != n_hid:
768
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
316 raise NotImplementedError('different length of hidden in the weights list',a.shape)
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
317
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
318 zidx=numpy.zeros((len(idx_list)+1), dtype=int) # integer offsets, used below as slice bounds
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
319
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
320 for i in range(len(idx_list)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
321 if idx_list[i] == 0:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
322 zidx[i+1] = (args[1][i].shape)[1]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
323 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
324 zidx[i+1] = (args[3+idx_list[i]-1].shape)[1]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
325
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
326 zidx=zidx.cumsum()
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
327 hidcalc = False
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
328 zcalc = [False for i in range((len(args)-4))]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
329
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
330 for i in range(len(idx_list)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
331 if idx_list[i]>0:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
332 if zcalc[int(idx_list[i])-1]:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
333 z[int(idx_list[i])][0] = self.m.dotin1(args[-1][:,zidx[i]:zidx[i+1]],z[int(idx_list[i])][0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
334 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
335 z[int(idx_list[i])][0] = self.m.dot1(args[-1][:,zidx[i]:zidx[i+1]])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
336 zcalc[int(idx_list[i])-1] = True
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
337 if hidcalc:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
338 z[0][0] = self.m.dotin2(args[-1][:,zidx[i]:zidx[i+1]],args[3+int(idx_list[i]-1)],z[0][0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
339 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
340 z[0][0] = self.m.dot2(args[-1][:,zidx[i]:zidx[i+1]],args[3+int(idx_list[i]-1)])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
341 hidcalc = True
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
342
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
343 if not hidcalc:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
344 z[0][0] = numpy.zeros((self.m.hidt.shape[1],self.m.hidt.shape[0]))
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
345
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
346 for i in range((len(args)-4)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
347 if not zcalc[i]:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
348 shp = args[3+i].shape
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
349 z[i+1][0] = numpy.zeros(shp)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
350
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
351
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
352 def __hash__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
353 return hash(ScanDotDecGrad)^87445
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
354
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
355 def __str__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
356 return "ScanDotDecGrad"
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
357
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
358 # DAA input noise------------------------------------
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
359 class ScanNoise(Op):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
360 """This Op takes an index list (as tensor.ivector), a list of matrices representing
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
361 the available inputs (as theano.generic), a probability of individual bit masking and
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
362 a probability of modality masking. It will return the inputs list with random entries set to zero
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
363 and the index list with some positive values changed to negative values (group masking)"""
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
364 def __init__(self, seed = 1):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
365 self.M=theano.Module()
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
366 self.M.rand = T.RandomStreams(seed)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
367 self.seed = seed
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
368 mat = T.matrix('mat')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
369 noise_level_bit = T.dscalar('noise_level_bit')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
370 noise_level_group = T.dscalar('noise_level_group')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
371 self.M.out1 = self.M.rand.binomial(T.shape(mat), 1, 1 - noise_level_bit) * mat
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
372 self.M.out2 = self.M.rand.binomial((1,1), 1, 1 - noise_level_group)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
373
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
374 self.M.noisify_bit = theano.Method([mat,noise_level_bit],self.M.out1)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
375 self.M.noisify_group_bool = theano.Method([noise_level_group],self.M.out2)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
376 self.R = self.M.make()
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
377 self.R.rand.initialize()
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
378
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
379 def make_node(self, idx_list, inputs_list, noise_level_bit, noise_level_group):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
380 idx_list = Checkidx_list(idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
381 return Apply(self, [idx_list] + [inputs_list] + [noise_level_bit] + [noise_level_group],\
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
382 [T.ivector(), theano.generic()])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
383
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
384 def perform(self, node, (idx_list,inputs_list,noise_level_bit,noise_level_group), (y,z)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
385
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
386 if len(idx_list) != len(inputs_list) :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
387 raise NotImplementedError('index list and inputs list differ in length',idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
388
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
389 y[0] = numpy.asarray([-i if (i>0 and not(self.R.noisify_group_bool(noise_level_group))) else i for i in idx_list])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
390 z[0] = [(self.R.noisify_bit(inputs_list[i],noise_level_bit) if y[0][i]>0 else numpy.zeros((inputs_list[i].shape)))\
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
391 for i in range(len(inputs_list))]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
392
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
393 def grad(self,args,gz):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
394 return [None,None,None,None]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
395
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
396
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
397 def __hash__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
398 return hash(ScanNoise)^hash(self.seed)^hash(self.R.rand)^12254
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
399
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
400 def __str__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
401 return "ScanNoise"
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
402
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
403 scannoise=ScanNoise()
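The two kinds of masking described in the ScanNoise docstring can be illustrated without Theano. The following is a minimal NumPy sketch, not the Op itself: the function name `scan_noise_sketch` and the explicit `rng` argument are assumptions made for the example, and the random draws use `numpy.random.default_rng` rather than `T.RandomStreams`.

```python
import numpy as np

def scan_noise_sketch(idx_list, inputs_list, noise_level_bit, noise_level_group, rng):
    """NumPy-only illustration of the masking described in the ScanNoise docstring.

    Hypothetical helper: it mirrors the logic of ScanNoise.perform but is not
    part of the original module.
    """
    if len(idx_list) != len(inputs_list):
        raise ValueError("index list and inputs list differ in length")
    # Group masking: a present group (index > 0) is dropped with probability
    # noise_level_group, which is recorded by negating its index.
    noisy_idx = np.asarray(
        [-i if (i > 0 and rng.random() < noise_level_group) else i for i in idx_list]
    )
    noisy_inputs = []
    for i, mat in zip(noisy_idx, inputs_list):
        if i > 0:
            # Bit masking: each entry survives with probability 1 - noise_level_bit.
            keep = rng.random(mat.shape) >= noise_level_bit
            noisy_inputs.append(keep * mat)
        else:
            # Missing (0) or masked (< 0) groups become all-zero matrices.
            noisy_inputs.append(np.zeros(mat.shape))
    return noisy_idx, noisy_inputs
```

With both noise levels at 0 the inputs come back unchanged. With `noise_level_group` at 1 every positive index is negated and every matrix is zeroed.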
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
404
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
405 # Total input matrix construction------------------------------------
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
406 class ScanInputs(Op):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
407 """This Op takes an index list (as tensor.ivector) and a list of matrices representing
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
408 the available inputs (as theano.generic). It will construct the appropriate tensor.dmatrix
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
409 to compare to the reconstruction obtained with ScanDotDec"""
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
410 def make_node(self, idx_list, inputs_list):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
411 idx_list = Checkidx_list(idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
412 return Apply(self, [idx_list] + [inputs_list],[T.dmatrix()])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
413
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
414 def perform(self, node, (idx_list, inputs_list), (z,)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
415
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
416 if len(idx_list) != len(inputs_list):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
417 raise NotImplementedError('index list and inputs list differ in length',idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
418
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
419 for i in range(len(idx_list)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
420 if idx_list[i] == 0:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
421 inputs_list[i] = 0 * inputs_list[i]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
422
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
423 z[0] = numpy.concatenate(inputs_list,1)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
424
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
425 def grad(self,args,gz):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
426 return [None,None]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
427
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
428 def __hash__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
429 return hash(ScanInputs)^75902
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
430
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
431 def __str__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
432 return "ScanInputs"
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
433
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
434 scaninputs=ScanInputs()
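As a rough illustration of what ScanInputs computes, the sketch below builds the reconstruction target with plain NumPy. The name `scan_inputs_sketch` is hypothetical. Unlike the listing, which overwrites entries of `inputs_list`, the sketch builds a new list so the caller's matrices are left untouched.

```python
import numpy as np

def scan_inputs_sketch(idx_list, inputs_list):
    """Concatenate the input groups column-wise, zeroing groups whose index is 0.

    Hypothetical helper mirroring ScanInputs.perform; not part of the module.
    """
    if len(idx_list) != len(inputs_list):
        raise ValueError("index list and inputs list differ in length")
    # Only index 0 (missing group) is zeroed. Negative indices (groups masked
    # by the noise step) keep their clean values in the target.
    blocks = [
        np.zeros_like(mat) if idx == 0 else mat
        for idx, mat in zip(idx_list, inputs_list)
    ]
    return np.concatenate(blocks, axis=1)
```

For example, three groups of widths 1, 2 and 1 with indices `[1, 0, -2]` give a matrix of width 4 whose two middle columns are zero.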
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
435
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
436 # Decoding bias vector construction------------------------------------
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
437 class ScanBiasDec(Op):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
438 """This Op takes an index list (as tensor.ivector), a list of matrices representing
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
439 the available inputs (as theano.generic) and the decoding bias tensor.dvector.
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
440 It will construct the appropriate bias tensor.dvector
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
441 to add to the reconstruction obtained with ScanDotDec"""
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
442 def make_node(self, idx_list, input_list, bias_list):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
443 idx_list = Checkidx_list(idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
444 bias_list = Checkbias_list(bias_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
445 return Apply(self, [idx_list] + [input_list] + bias_list, [T.dvector()])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
446
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
447 def perform(self, node, args, (z,)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
448 idx_list = abs(args[0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
449
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
450 if max(idx_list) >= (len(args)-2)+1 :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
451 raise NotImplementedError('index exceeds bias list length',idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
452 if len(idx_list) != len(args[1]) :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
453 raise NotImplementedError('index list and inputs list differ in length',idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
454 z[0] = [args[idx_list[i]+1] if idx_list[i] != 0 else numpy.zeros(args[1][i].shape[1]) \
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
455 for i in range(len(idx_list))]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
456 z[0] = numpy.concatenate(z[0]) # the bias pieces are 1-D vectors, so join them along axis 0
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
457
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
458 def __hash__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
459 return hash(ScanBiasDec)^60056
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
460
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
461 def grad(self,args,gz):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
462 gradi = ScanBiasDecGrad()(args,gz)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
463 if type(gradi) != list:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
464 return [None, None] + [gradi]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
465 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
466 return [None, None] + gradi
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
467
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
468 def __str__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
469 return "ScanBiasDec"
686
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
470
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
471 scanbiasdec=ScanBiasDec()
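The bias construction in ScanBiasDec can be sketched in NumPy as follows. This is an illustration under assumptions, not the Op: `scan_bias_dec_sketch` is a hypothetical name, and the biases are passed as one list instead of as trailing positional arguments.

```python
import numpy as np

def scan_bias_dec_sketch(idx_list, inputs_list, bias_list):
    """Build the decoding bias vector matching the concatenated input groups.

    Hypothetical helper mirroring ScanBiasDec.perform; not part of the module.
    """
    idx_list = np.abs(np.asarray(idx_list))  # masked groups (< 0) still get a bias
    if idx_list.max() > len(bias_list):
        raise ValueError("index exceeds bias list length")
    if len(idx_list) != len(inputs_list):
        raise ValueError("index list and inputs list differ in length")
    pieces = [
        # Index k (1-based) selects bias k; index 0 yields zeros as wide as the group.
        bias_list[k - 1] if k != 0 else np.zeros(mat.shape[1])
        for k, mat in zip(idx_list, inputs_list)
    ]
    return np.concatenate(pieces)
```

The result has one entry per column of the matrix built by ScanInputs, so it can be added to the reconstruction row by row.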
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
472
0457dfa6fcad add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff changeset
473 class ScanBiasDecGrad(Op):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
474 """This Op computes the gradient wrt the bias for ScanBiasDec"""
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
475 def make_node(self, args, g_out):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
476 idx_list = Checkidx_list(args[0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
477 bias_list = Checkbias_list(args[2:])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
478 return Apply(self, args + g_out, [T.dvector() for i in range(len(args)-2)])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
479
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
480 def perform(self, node, args, z):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
481 idx_list = abs(args[0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
482
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
483 if max(idx_list) >= (len(args)-3)+1 :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
484 raise NotImplementedError('index exceeds bias list length',idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
485 if len(idx_list) != len(args[1]) :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
486 raise NotImplementedError('index list and inputs list differ in length',idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
487
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
488 zidx=numpy.zeros((len(idx_list)+1), dtype=int) # integer offsets, used as slice bounds below
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
489 for i in range(len(idx_list)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
490 if idx_list[i] == 0:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
491 zidx[i+1] = (args[1][i].shape)[1]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
492 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
493 zidx[i+1] = (args[2+idx_list[i]-1].size)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
494 zidx=zidx.cumsum()
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
495 zcalc = [False for i in range((len(args)-3))]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
496
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
497 for i in range(len(idx_list)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
498 if idx_list[i]>0:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
499 if zcalc[int(idx_list[i])-1]:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
500 z[int(idx_list[i])-1][0] += args[-1][zidx[i]:zidx[i+1]]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
501 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
502 z[int(idx_list[i])-1][0] = args[-1][zidx[i]:zidx[i+1]].copy() # copy so a later += cannot modify g_out in place
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
503 zcalc[int(idx_list[i])-1] = True
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
504
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
505 for i in range((len(args)-3)):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
506 if not zcalc[i]:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
507 shp = args[2+i].size
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
508 z[i][0] = numpy.zeros(shp)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
509
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
510
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
511 def __hash__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
512 return hash(ScanBiasDecGrad)^41256
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
513
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
514 def __str__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
515 return "ScanBiasDecGrad"
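The gradient computed by ScanBiasDecGrad amounts to slicing the output gradient at the group boundaries and summing the slices that belong to the same bias. The sketch below shows this in NumPy. The name `scan_bias_dec_grad_sketch` and the list-based signature are assumptions made for the example.

```python
import numpy as np

def scan_bias_dec_grad_sketch(idx_list, inputs_list, bias_list, g_out):
    """Accumulate slices of g_out into one gradient vector per bias.

    Hypothetical helper mirroring ScanBiasDecGrad.perform; not part of the module.
    """
    idx_list = np.abs(np.asarray(idx_list))
    # Width of each group in the concatenated vector: the input width for a
    # missing group (index 0), otherwise the size of the selected bias.
    widths = [
        inputs_list[i].shape[1] if k == 0 else bias_list[k - 1].size
        for i, k in enumerate(idx_list)
    ]
    bounds = np.concatenate(([0], np.cumsum(widths))).astype(int)
    # Start every gradient at zero so unused biases get a zero gradient and
    # g_out itself is never modified.
    grads = [np.zeros(b.size) for b in bias_list]
    for i, k in enumerate(idx_list):
        if k > 0:
            grads[k - 1] += g_out[bounds[i]:bounds[i + 1]]
    return grads
```

Starting from zero-filled arrays avoids any aliasing between the returned gradients and `g_out`, at the cost of one extra allocation per bias.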
694
69947f4e9c0e added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 686
diff changeset
516
69947f4e9c0e added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 686
diff changeset
517 # Mask construction------------------------------------
69947f4e9c0e added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 686
diff changeset
518 class ScanMask(Op):
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
519 """This Op takes an index list (as tensor.ivector) and a list of weights.
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
520 It will construct a list of T.iscalar representing the Mask
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
521 to do the correct regularisation on the weights"""
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
522 def __init__(self,encbool=True):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
523 self.encbool = encbool
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
524
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
525 def make_node(self, idx_list, weights_list):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
526 idx_list = Checkidx_list(idx_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
527 weights_list = Checkweights_list(weights_list)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
528 return Apply(self, [idx_list] + weights_list, [T.iscalar() for i in range(len(weights_list))])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
529
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
530 def perform(self, node, args, z):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
531 if self.encbool:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
532 idx_list = args[0]
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
533 dim = 1
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
534 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
535 idx_list = abs(args[0])
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
536 dim = 0
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
537 n_hid = args[1].shape[dim]
701
113946723973 fixed bug of scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 700
diff changeset
538
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
539 if max(idx_list) >= (len(args)-1)+1 :
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
540 raise NotImplementedError('index superior to weights list length',idx_list)
768
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
541 for a in args[1:]:
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
542 if a.shape[dim] != n_hid:
bd95e8ea99d8 small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 767
diff changeset
543 raise NotImplementedError('different length of hidden in the encoding weights list',a.shape)
714
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
544
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
545 for i in range(len(args[1:])):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
546 z[i][0] = numpy.asarray((idx_list == i+1).sum(),dtype='int32')
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
547
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
548 def __hash__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
549 return hash(ScanMask)^hash(self.encbool)^11447
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
550
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
551 def grad(self,args,gz):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
552 return [None] * len(args)
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
553
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
554 def __str__(self):
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
555 if self.encbool:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
556 string = "Enc"
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
557 else:
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
558 string = "Dec"
8d5d42274bd1 improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 711
diff changeset
559 return "ScanMask" + string
694
69947f4e9c0e added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 686
diff changeset
560
69947f4e9c0e added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents: 686
diff changeset
561 scanmaskenc=ScanMask(True)
711
0eae6d5315b5 Fixed minor typo in comment
Olivier Delalleau <delallea@iro>
parents: 701
diff changeset
562 scanmaskdec=ScanMask(False)
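
The counting done in ScanMask.perform above can be sketched with NumPy alone. This is a minimal illustration, not the Op itself: the function name `scan_mask_counts` and the `n_weights` argument are hypothetical, and the check on the hidden-layer sizes of the weight matrices is left out.

```python
import numpy

def scan_mask_counts(idx_list, n_weights, encbool=True):
    """For each weight matrix i (1-based), count how often idx_list selects it."""
    idx = numpy.asarray(idx_list)
    if not encbool:
        # The decoder variant works on absolute values of the indices.
        idx = numpy.abs(idx)
    if idx.size and idx.max() >= n_weights + 1:
        raise ValueError('index superior to weights list length')
    # One int32 scalar per weight matrix, as in the Op's outputs.
    return [numpy.asarray((idx == i + 1).sum(), dtype='int32')
            for i in range(n_weights)]

enc_counts = scan_mask_counts([1, 0, 2, 1], n_weights=2)
dec_counts = scan_mask_counts([-1, 0, 2, -1], n_weights=2, encbool=False)
```

An index of 0 matches no weight matrix, so it contributes to none of the counts.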
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
563
717
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
564 # TODO The classes FillMissing and MaskSelect below should probably be moved
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
565 # to another (more appropriate) file.
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
566 class FillMissing(Op):
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
567 """
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
568 Given an input, output two elements:
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
569 - a copy of the input where missing values (NaN) are replaced by some
745
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
570 other value (zero by default)
720
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
571 - a mask of the same size and type as input, where each element is zero
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
572 iff the corresponding input is missing
745
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
573 The 'fill_with' parameter may either be:
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
574 - a scalar: all missing values are replaced with this value
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
575 - a Numpy array: a missing value is replaced by the value in this array
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
576 at the same position (ignoring the first k dimensions if 'fill_with'
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
577 has k fewer dimensions than the input)
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
578 Currently, the gradient is computed as if the input value was really what
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
579 it was replaced with. It may be safer to replace the gradient w.r.t.
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
580 missing values with either zeros or missing values (?).
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
581 """
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
582
745
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
583 def __init__(self, fill_with=0):
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
584 super(FillMissing, self).__init__()
745
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
585 self.fill_with = fill_with
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
586 self.fill_with_is_array = isinstance(self.fill_with, numpy.ndarray)
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
587
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
588 def __eq__(self, other):
745
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
589 return (type(self) == type(other) and
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
590 self.fill_with_is_array == other.fill_with_is_array and
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
591 ((self.fill_with_is_array and
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
592 (self.fill_with == other.fill_with).all()) or
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
593 (not self.fill_with_is_array and self.fill_with == other.fill_with)))
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
594
745
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
595 def __hash__(self):
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
596 if self.fill_with_is_array:
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
597 fill_hash = self.fill_with.__hash__()
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
598 else:
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
599 fill_hash = hash(self.fill_with)
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
600 return hash(type(self))^hash(self.fill_with_is_array)^fill_hash
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
601
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
602 def make_node(self, input):
720
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
603 return Apply(self, [input], [input.type(), input.type()])
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
604
724
d42b4bcbb582 Replaced debug special code for missing values (-123456) by truly missing (NaN)
Olivier Delalleau <delallea@iro>
parents: 720
diff changeset
605 def perform(self, node, (input, ), output_storage):
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
606 out = output_storage[0]
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
607 out[0] = input.copy()
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
608 out = out[0]
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
609 mask = output_storage[1]
720
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
610 mask[0] = numpy.ones(input.shape, dtype=input.dtype)
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
611 mask = mask[0]
745
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
612 if self.fill_with_is_array:
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
613 ignore_k = len(out.shape) - len(self.fill_with.shape)
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
614 assert ignore_k >= 0
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
615 for (idx, v) in numpy.ndenumerate(out):
724
d42b4bcbb582 Replaced debug special code for missing values (-123456) by truly missing (NaN)
Olivier Delalleau <delallea@iro>
parents: 720
diff changeset
616 if numpy.isnan(v):
745
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
617 if self.fill_with_is_array:
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
618 out[idx] = self.fill_with[idx[ignore_k:]]
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
619 else:
fc85ce33b518 FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents: 724
diff changeset
620 out[idx] = self.fill_with
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
621 mask[idx] = 0
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
622
720
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
623 def grad(self, inputs, (out_grad, mask_grad, )):
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
624 return [out_grad]
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
625
716
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
626 fill_missing_with_zeros = FillMissing(0)
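
The forward pass of FillMissing can be approximated outside Theano with vectorised NumPy. The sketch below is an assumption-laden illustration: `fill_missing` is a hypothetical helper, it casts the input to float, and it assumes an array-valued `fill_with` has exactly the shape of the trailing dimensions of the input, so that broadcasting matches the `idx[ignore_k:]` lookup used in perform.

```python
import numpy

def fill_missing(x, fill_with=0):
    """Return (filled copy of x, mask) where mask is 0 at NaN positions."""
    x = numpy.asarray(x, dtype=float)
    out = x.copy()
    mask = numpy.ones(x.shape, dtype=x.dtype)
    missing = numpy.isnan(x)
    if isinstance(fill_with, numpy.ndarray):
        # fill_with may have k fewer leading dimensions than x.
        assert x.ndim - fill_with.ndim >= 0
        fill = numpy.broadcast_to(fill_with, x.shape)
        out[missing] = fill[missing]
    else:
        out[missing] = fill_with
    mask[missing] = 0
    return out, mask

nan = float('nan')
data = [[1.0, nan, 3.0], [nan, 5.0, nan]]
out_scalar, mask_scalar = fill_missing(data)
out_array, mask_array = fill_missing(data, numpy.array([10.0, 20.0, 30.0]))
```

With the array argument, each missing value in a row is replaced by the entry of `fill_with` in the same column.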
d7c6dadb4aa9 New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents: 711
diff changeset
627
720
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
628 class MaskGradient(Op):
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
629 """
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
630 Takes as input a tensor and a mask. Outputs the same tensor, but setting
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
631 to zero the gradient for all elements where the mask's value is zero.
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
632 """
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
633
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
634 def __eq__(self, other):
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
635 return type(self) == type(other)
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
636
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
637 def __hash__(self):
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
638 return hash(type(self))
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
639
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
640 def make_node(self, input, mask):
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
641 return Apply(self, [input, mask], [input.type()])
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
642
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
643 def perform(self, node, (input, mask), (output, )):
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
644 output[0] = input.copy()
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
645
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
646 def grad(self, (input, mask), (out_grad, )):
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
647 return [out_grad * T.neq(mask, 0), None]
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
648
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
649 mask_gradient = MaskGradient()
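
As a rough NumPy analogue of MaskGradient.grad, the incoming gradient is multiplied elementwise by an indicator of the non-zero mask entries. The function name `masked_backward` is hypothetical and the example works on concrete arrays rather than symbolic Theano variables.

```python
import numpy

def masked_backward(out_grad, mask):
    """Zero the gradient wherever the mask is zero; pass it through elsewhere."""
    out_grad = numpy.asarray(out_grad, dtype=float)
    keep = numpy.asarray(mask) != 0
    return out_grad * keep

grad_in = masked_backward([0.5, 1.0, 2.0], [1, 0, 1])
```

The forward pass is the identity, so only the backward pass is affected by the mask.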
0594cba02fa8 Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents: 719
diff changeset
650
717
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
651 class MaskSelect(Op):
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
652 """
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
653 Given an input x and a mask m (both vectors), outputs a vector that
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
654 contains all elements x[i] such that bool(m[i]) is True.
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
655 """
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
656
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
657 def __eq__(self, other):
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
658 return type(self) == type(other)
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
659
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
660 def __hash__(self):
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
661 return hash(type(self))
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
662
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
663 def make_node(self, input, mask):
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
664 return Apply(self, [input, mask], [input.type()])
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
665
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
666 def perform(self, node, (input, mask), (output, )):
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
667 select = []
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
668 for (i, m) in enumerate(mask):
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
669 if bool(m):
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
670 select.append(i)
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
671 output[0] = numpy.zeros(len(select), dtype = input.dtype)
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
672 out = output[0]
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
673 for (i, j) in enumerate(select):
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
674 out[i] = input[j]
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
675
bf29e201515f New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents: 716
diff changeset
676 mask_select = MaskSelect()
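
The selection performed by MaskSelect.perform is equivalent to NumPy boolean indexing on vectors. A minimal sketch, where `mask_select_numpy` is a hypothetical name used only for this illustration:

```python
import numpy

def mask_select_numpy(x, mask):
    """Keep the elements x[i] for which bool(mask[i]) is True."""
    x = numpy.asarray(x)
    keep = numpy.asarray(mask).astype(bool)
    return x[keep]

selected = mask_select_numpy([10, 20, 30, 40], [1, 0, 0, 1])
```

The result keeps the dtype of the input and has one entry per truthy mask element.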