annotate pylearn/sandbox/scan_inputs_groups.py @ 768:bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
author | Xavier Glorot <glorotxa@iro.umontreal.ca> |
---|---|
date | Mon, 08 Jun 2009 13:14:09 -0400 |
parents | 1e97e7c7f11f |
children | 742972b6906a |
rev | line source |
---|---|
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
import numpy
import theano
from theano import tensor as T
from theano.gof import Op
from theano.gof import Apply
from theano import scalar as scal

# These Ops allow us to deal efficiently with static groups of possibly missing inputs in the dense DAA framework
# (for example with multimodal data where an entire modality is sometimes missing).
# The inputs will be represented with an index list and a theano.generic variable (which will be a list of matrices
711
0eae6d5315b5
Fixed minor typo in comment
Olivier Delalleau <delallea@iro>
parents:
701
# (numpy array), each element will correspond to an available modality and the index list will indicate the weights
# associated with it).
# Example of index list: [1, 0, -3]
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
#   * the 1 means that the first element of the input list refers to the first element of weights_list
#     (auxiliary target as input):
#     if indexlist[i] > 0, inputslist[i] refers to Weightslist[indexlist[i]-1]
#   * the 0 means that the second element of the input list will be neither encoded nor decoded (it is replaced by zeros);
#     this is not efficient, so in this case it is better to give [1,-3] and [inputslist[0],inputslist[2]],
#     but it allows us to deal with empty lists: give indexlist = numpy.asarray([.0])
#     and inputlist = numpy.zeros((batchsize,1))
#   * a negative index means that the input will not be used for encoding but will still be reconstructed
#     (auxiliary target as output):
#     if indexlist[i] < 0, inputslist[i] refers to Weightslist[-indexlist[i]-1]
#
# An entire batch should have the same configuration of available inputs.
#
# Dense DAA Example:----------------------------------------------------------------------------
#
#from theano.tensor.nnet import sigmoid
#
#nb_modality = 4
#wenc = [T.dmatrix('wenc%s'%i) for i in range(nb_modality)]
#wdec = [T.dmatrix('wdec%s'%i) for i in range(nb_modality)]
#benc = T.dvector('benc')
#bdec = [T.dvector('bdec%s'%i) for i in range(nb_modality)]
#vectin = T.ivector('vectin')
#inputpart = theano.generic('inputpart')
#noise_bit = T.dscalar('noise_bit')
#noise_group = T.dscalar('noise_group')
#
#[vectin2,inputpart2] = scannoise(vectin,inputpart,noise_bit,noise_group)
#hid = scandotenc(vectin2, inputpart2, wenc)
#acthid = sigmoid(hid + benc)
#dec = sigmoid(scanbiasdec(vectin2,inputpart2,bdec) + scandotdec(vectin2, inputpart2,acthid,wdec))
#cost = T.sum(T.sum(T.sqr( scaninput(vectin,inputpart) - dec ),1),0)

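The index-list convention described in the header comment can be illustrated without Theano. The following is only a sketch of how one might decode such a list: the function name `decode_index_list` and the role labels are hypothetical and do not appear in the original file.

```python
def decode_index_list(idx_list):
    """Return (role, weight_position) for each entry of an index list.

    The role strings are hypothetical labels used only in this sketch.
    """
    decoded = []
    for idx in idx_list:
        idx = int(idx)
        if idx > 0:
            # Used for encoding; its weights sit at position idx - 1.
            decoded.append(("encode", idx - 1))
        elif idx < 0:
            # Not used for encoding but still reconstructed;
            # its weights sit at position -idx - 1.
            decoded.append(("reconstruct_only", -idx - 1))
        else:
            # Neither encoded nor decoded (replaced by zeros).
            decoded.append(("skip", None))
    return decoded


# The example index list from the header comment:
example = decode_index_list([1, 0, -3])
# example == [("encode", 0), ("skip", None), ("reconstruct_only", 2)]
```

Under this reading, the weight position is always `abs(idx) - 1`, and only the sign decides whether the input takes part in encoding.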
# Checking inputs in make_node methods----------------------
def Checkidx_list(idx_list):
    idx_list = T.as_tensor_variable(idx_list)
    nidx = idx_list.type.ndim
    if nidx != 1: raise TypeError('not vector', idx_list)
    return idx_list

def Checkhidd(hidd):
    hidd = T.as_tensor_variable(hidd)
    nhidd = hidd.type.ndim
    if nhidd not in (1,2): raise TypeError('not matrix or vector', hidd)
    return hidd

def Checkweights_list(weights_list):
    weights_list = map(T.as_tensor_variable, weights_list)
    for i in range(len(weights_list)):
        nweights = weights_list[i].type.ndim
        if nweights not in (1,2): raise TypeError('not matrix or vector', weights_list[i])
    return weights_list

def Checkbias_list(bias_list):
    bias_list = map(T.as_tensor_variable, bias_list)
    for i in range(len(bias_list)):
        nbias = bias_list[i].type.ndim
        if nbias != 1: raise TypeError('not vector', bias_list[i])
    return bias_list

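The Check helpers above all follow the same pattern: convert, read the number of dimensions, and raise `TypeError` when it is not an accepted value. A rough Theano-free sketch of that pattern is given below. It assumes nested Python lists stand in for tensors, and the names `nested_ndim`, `check_ndim`, `check_bias_list` and `check_weights_list` are hypothetical.

```python
def nested_ndim(value):
    """Nesting depth of a list/tuple (0 for a scalar).

    A stand-in for `variable.type.ndim`; only the first element of each
    level is inspected, so ragged inputs are not detected.
    """
    ndim = 0
    while isinstance(value, (list, tuple)):
        ndim += 1
        if not value:
            break
        value = value[0]
    return ndim


def check_ndim(value, allowed, message):
    """Return value unchanged, or raise TypeError if its depth is not allowed."""
    if nested_ndim(value) not in allowed:
        raise TypeError(message, value)
    return value


def check_bias_list(bias_list):
    # Same rule as Checkbias_list: every bias must be a vector.
    return [check_ndim(b, (1,), 'not vector') for b in bias_list]


def check_weights_list(weights_list):
    # Same rule as Checkweights_list: vectors or matrices are accepted.
    return [check_ndim(w, (1, 2), 'not matrix or vector') for w in weights_list]
```

Factoring the rule into one helper is a design choice of this sketch only; the original keeps one function per argument kind.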
# Encoding scan dot product------------------------------------
class ScanDotEnc(Op):
    """This Op takes an index list (as tensor.ivector), a list of matrices representing
    the available inputs (as theano.generic), and all the encoding weights of the model (each a tensor.dmatrix).
    It will select the weights corresponding to the inputs (according to the index list) and compute only
    the necessary dot products."""
    def __init__(self):
        #Create Theano methods to do the dot products with blas or at least in C.
        self.M=theano.Module()
        inputs = T.dmatrix('input')
        weights = T.dmatrix('weights')
        self.M.hid = T.dmatrix('hid')
        self.M.resultin = self.M.hid + T.dot(inputs,weights)
        result = T.dot(inputs,weights)

        self.M.dotin = theano.Method([inputs,weights],None,{self.M.hid : self.M.resultin})
        self.M.dot = theano.Method([inputs,weights],result)
        self.m = self.M.make()

    def make_node(self, idx_list, inputs_list, weights_list):
        idx_list = Checkidx_list(idx_list)
        weights_list = Checkweights_list(weights_list)
        return Apply(self, [idx_list] + [inputs_list] + weights_list, [T.dmatrix()])

    def perform(self, node, args, (hid,)):
        # args = [idx_list, inputs_list, weights_0, weights_1, ...] (see make_node)
        idx_list = args[0]
        hidcalc = False

        batchsize = (args[1][0].shape)[0]
        n_hid = (args[2].shape)[1]
        if len(idx_list) != len(args[1]) :
            raise NotImplementedError('size of index different of inputs list size',idx_list)
        # indices are 1-based, so the largest valid value is the number of weight matrices
        if max(idx_list) >= (len(args)-2)+1 :
            raise NotImplementedError('index superior to weight list length',idx_list)
        for a in args[1]:
            if (a.shape)[0] != batchsize:
                raise NotImplementedError('different batchsize in the inputs list',a.shape)
        for a in args[2:]:
            if (a.shape)[1] != n_hid:
                raise NotImplementedError('different length of hidden in the weights list',a.shape)

        # only strictly positive indices take part in the encoding
        for i in range(len(idx_list)):
            if idx_list[i]>0:
                if hidcalc:
                    # accumulate into the stored hidden value
                    self.m.dotin(args[1][i],args[2+int(idx_list[i]-1)])
                else:
                    # first contribution: initialise the stored hidden value
                    self.m.hid = self.m.dot(args[1][i],args[2+int(idx_list[i]-1)])
                    hidcalc = True

        if not hidcalc:
            hid[0] = numpy.zeros([batchsize,n_hid])
        else:
            hid[0] = self.m.hid


    def grad(self, args, gz):
        gradi = ScanDotEncGrad()(args,gz)
        if type(gradi) != list:
            return [None, None] + [gradi]
        else:
            return [None, None] + gradi

    def __hash__(self):
        return hash(ScanDotEnc)^58994

    def __str__(self):
        return "ScanDotEnc"

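The computation carried out by `ScanDotEnc.perform` can be sketched without Theano. The version below is a minimal illustration, not the Op itself. It assumes matrices are plain lists of rows, and the helper names `matmul` and `scan_dot_enc` are hypothetical. The shape checks of the original are omitted for brevity.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols]
            for row in a]


def scan_dot_enc(idx_list, inputs_list, weights_list):
    """Sum inputs_list[i] . weights_list[idx_list[i] - 1] over positive indices.

    Entries with a zero or negative index are skipped, so if no index is
    positive the result is a (batchsize, n_hid) matrix of zeros.
    """
    batchsize = len(inputs_list[0])
    n_hid = len(weights_list[0][0])
    hid = [[0.0] * n_hid for _ in range(batchsize)]
    for idx, x in zip(idx_list, inputs_list):
        if idx > 0:
            prod = matmul(x, weights_list[int(idx) - 1])
            hid = [[h + p for h, p in zip(hrow, prow)]
                   for hrow, prow in zip(hid, prod)]
    return hid


# One example per batch, three modalities, two hidden units.
inputs = [[[1.0, 2.0]], [[5.0]], [[7.0]]]
weights = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0]], [[2.0, 2.0]]]
hid = scan_dot_enc([1, 0, -3], inputs, weights)  # only the first input is encoded
```

Starting from a zero matrix and always accumulating avoids the `hidcalc` flag of the original; the original instead initialises from the first product, presumably to save one addition.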
scandotenc=ScanDotEnc()

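`ScanDotEnc.grad` delegates to a separate gradient Op. For a hidden value of the form hid = sum_i X_i . W_k (with k = idx_i - 1 for positive indices), the gradient with respect to W_k is the sum of transpose(X_i) . g_out over the inputs that use W_k. The sketch below illustrates that formula in plain Python. The names `matmul`, `transpose` and `scan_dot_enc_grad` are hypothetical, and returning zero matrices for weights that are not used is an assumption of this sketch, not something shown in the excerpt.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols]
            for row in a]


def transpose(a):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*a)]


def scan_dot_enc_grad(idx_list, inputs_list, weights_list, g_out):
    """Gradient of sum_i X_i . W_{idx_i - 1} with respect to each weight matrix."""
    # Assumption: unused weights receive a zero gradient of matching shape.
    grads = [[[0.0] * len(w[0]) for _ in w] for w in weights_list]
    for idx, x in zip(idx_list, inputs_list):
        if idx > 0:
            k = int(idx) - 1
            contrib = matmul(transpose(x), g_out)
            grads[k] = [[g + c for g, c in zip(grow, crow)]
                        for grow, crow in zip(grads[k], contrib)]
    return grads


inputs = [[[1.0, 2.0]], [[3.0]]]
weights = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0]]]
g_out = [[10.0, 20.0]]
grads = scan_dot_enc_grad([1, -2], inputs, weights, g_out)
```

Here the second input has a negative index, so it contributes nothing to the encoding and its weight matrix gets a zero gradient.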
class ScanDotEncGrad(Op):
    """This Op computes the gradient wrt the weights for ScanDotEnc"""
    def __init__(self):
        #Create Theano methods to do the dot products with blas or at least in C.
        self.M=theano.Module()
        input1 = T.dmatrix('input1')
        self.M.g_out = T.dmatrix('g_out')
        result = T.dmatrix('result')
        input2=T.transpose(input1)
        self.M.resultin = result + T.dot(input2,self.M.g_out)
        self.M.result = T.dot(input2,self.M.g_out)

        self.M.dotin = theano.Method([input1,result],self.M.resultin)
        self.M.dot = theano.Method([input1],self.M.result)
        self.m = self.M.make()

8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
159 def make_node(self, args, g_out): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
160 idx_list = Checkidx_list(args[0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
161 weights_list = Checkweights_list(args[2:]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
162 return Apply(self, args + g_out, [T.dmatrix() for i in xrange(2,len(args))]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
163 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
164 def perform(self, node, args, z): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
165 idx_list = args[0] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
166 self.m.g_out = args[-1] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
167 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
168 batchsize = (args[1][0].shape)[0] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
169 n_hid = (args[2].shape)[1] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
170 if len(idx_list) != len(args[1]) : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
171 raise NotImplementedError('index list size differs from inputs list size',idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
172 if max(idx_list) >= (len(args)-3)+1 : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
173 raise NotImplementedError('index exceeds weight list length',idx_list) |
768
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
174 for a in args[1]: |
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
175 if (a.shape)[0] != batchsize: |
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
176 raise NotImplementedError('different batchsize in the inputs list',a.shape) |
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
177 for a in args[2:-1]: |
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
178 if (a.shape)[1] != n_hid: |
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
179 raise NotImplementedError('different length of hidden in the weights list',a.shape) |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
180 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
181 zcalc = [False for i in range(len(args)-3)] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
182 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
183 for i in range(len(idx_list)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
184 if idx_list[i]>0: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
185 if zcalc[int(idx_list[i]-1)]: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
186 z[int(idx_list[i]-1)][0] = self.m.dotin(args[1][i],z[int(idx_list[i]-1)][0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
187 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
188 z[int(idx_list[i]-1)][0] = self.m.dot(args[1][i]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
189 zcalc[int(idx_list[i]-1)] = True |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
190 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
191 for i in range(len(args)-3): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
192 if not zcalc[i]: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
193 shp = args[2+i].shape |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
194 z[i][0] = numpy.zeros(shp) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
195 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
196 def __hash__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
197 return hash(ScanDotEncGrad)^15684 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
198 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
199 def __str__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
200 return "ScanDotEncGrad" |
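The accumulation performed by `ScanDotEncGrad.perform` above can be sketched in plain NumPy. This is a minimal reference under stated assumptions, not the Op itself: the helper name `scan_dot_enc_grad_reference` is hypothetical, and it assumes that `input2` in the listing is the transpose of `input1`, so that each contribution is `input.T . g_out` with the same shape as the corresponding weight matrix.

```python
import numpy as np

def scan_dot_enc_grad_reference(idx_list, input_list, weights_list, g_out):
    """Hypothetical NumPy reference for the weight gradients of the encoder.

    idx_list[i] > 0 means input i is present and uses weights_list[idx_list[i] - 1];
    any other value means input i is missing and contributes nothing.
    """
    # Weights that are never selected keep an all-zero gradient.
    grads = [np.zeros_like(w, dtype=float) for w in weights_list]
    for i, k in enumerate(idx_list):
        if k > 0:
            # Assumption: the listing's input2 is the transpose of input1.
            grads[int(k) - 1] += input_list[i].T @ g_out
    return grads

# Small usage example with two weight matrices and three input slots.
idx = [1, 0, 1]
inputs = [np.ones((2, 3)), np.ones((2, 3)), 2 * np.ones((2, 3))]
weights = [np.zeros((3, 4)), np.zeros((5, 4))]
g_out = np.ones((2, 4))
grads = scan_dot_enc_grad_reference(idx, inputs, weights, g_out)
```

Inputs that share a weight index add their contributions together, which is what the `dotin` / `dot` pair and the `zcalc` flags do in the listing.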
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
201 |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
202 # Decoding scan dot product------------------------------------ |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
203 class ScanDotDec(Op): |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
204 """This Op takes an index list (as tensor.ivector), a list of matrices representing |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
205 the available inputs (as theano.generic), the hidden layer of the DAA (theano.dmatrix) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
206 and all the decoding weights (tensor.dmatrix) of the model. It will select the |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
207 weights corresponding to the available inputs (according to index list) and compute |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
208 only the necessary dot products. The outputs will be concatenated and will represent |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
209 the reconstruction of the different modalities in the same order as the index list""" |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
210 def __init__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
211 #Create Theano methods to do the dot products with blas or at least in C. |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
212 self.M=theano.Module() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
213 weights = T.dmatrix('weights') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
214 self.M.hid = T.dmatrix('hid') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
215 oldval = T.dmatrix('oldval') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
216 resultin = oldval + T.dot(self.M.hid,weights) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
217 result = T.dot(self.M.hid,weights) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
218 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
219 self.M.dotin = theano.Method([weights,oldval],resultin) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
220 self.M.dot = theano.Method([weights],result) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
221 self.m = self.M.make() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
222 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
223 def make_node(self, idx_list, input_list, hidd, weights_list): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
224 idx_list = Checkidx_list(idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
225 hidd = Checkhidd(hidd) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
226 weights_list = Checkweights_list(weights_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
227 return Apply(self, [idx_list] + [input_list] +[hidd] + weights_list,[T.dmatrix()]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
228 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
229 def perform(self, node, args, (z,)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
230 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
231 idx_list = abs(args[0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
232 self.m.hid = args[2] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
233 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
234 batchsize = (self.m.hid.shape)[0] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
235 n_hid = self.m.hid.shape[1] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
236 if max(idx_list) >= len(args)-3+1 : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
237 raise NotImplementedError('index exceeds weight list length',idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
238 if len(idx_list) != len(args[1]) : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
239 raise NotImplementedError('index list size differs from inputs list size',idx_list) |
768
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
240 for a in args[3:]: |
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
241 if (a.shape)[0] != n_hid: |
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
242 raise NotImplementedError('different length of hidden in the weights list',a.shape) |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
243 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
244 zcalc = [False for i in idx_list] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
245 z[0] = [None for i in idx_list] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
246 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
247 for i in range(len(idx_list)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
248 if idx_list[i]>0: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
249 if zcalc[i]: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
250 z[0][i] = self.m.dotin(args[3+int(idx_list[i]-1)],z[0][i]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
251 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
252 z[0][i] = self.m.dot(args[3+int(idx_list[i]-1)]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
253 zcalc[i] = True |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
254 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
255 for i in range(len(idx_list)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
256 if not zcalc[i]: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
257 shp = args[1][i].shape |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
258 z[0][i] = numpy.zeros((batchsize,shp[1])) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
259 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
260 z[0] = numpy.concatenate(z[0],1) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
261 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
262 def grad(self, args, gz): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
263 gradi = ScanDotDecGrad()(args,gz) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
264 if type(gradi) != list: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
265 return [None, None] + [gradi] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
266 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
267 return [None, None] + gradi |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
268 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
269 def __hash__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
270 return hash(ScanDotDec)^73568 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
271 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
272 def __str__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
273 return "ScanDotDec" |
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
274 |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
275 scandotdec=ScanDotDec() |
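What `ScanDotDec.perform` computes can be written as a short NumPy reference. This is a sketch under stated assumptions rather than the Op's implementation: the helper name `scan_dot_dec_reference` is hypothetical, it bypasses the Theano `Module`/`Method` machinery entirely, and it sizes the zero block for a missing modality from input `i` itself, which is an assumption about the intended behaviour.

```python
import numpy as np

def scan_dot_dec_reference(idx_list, input_list, hid, weights_list):
    """Hypothetical NumPy reference for the decoding scan dot product.

    For each modality i, abs(idx_list[i]) > 0 selects weights_list[abs(idx) - 1]
    and reconstructs hid . W; an index of 0 yields a block of zeros.
    """
    idx = np.abs(np.asarray(idx_list))
    if len(idx) != len(input_list):
        raise ValueError("index list and inputs list differ in length")
    if len(idx) and idx.max() > len(weights_list):
        raise ValueError("index exceeds weight list length")

    blocks = []
    for i, k in enumerate(idx):
        if k > 0:
            blocks.append(hid @ weights_list[int(k) - 1])
        else:
            # Assumption: the zero block is as wide as input i.
            blocks.append(np.zeros((hid.shape[0], input_list[i].shape[1])))
    # Reconstructions are concatenated column-wise, in index-list order.
    return np.concatenate(blocks, axis=1)

# Small usage example: three modalities, two decoding weight matrices.
hid = np.ones((2, 3))
weights = [np.ones((3, 4)), 2 * np.ones((3, 5))]
inputs = [np.zeros((2, 4)), np.zeros((2, 5)), np.zeros((2, 4))]
out = scan_dot_dec_reference([1, 0, -1], inputs, hid, weights)
```

Because the index list passes through `abs`, a negative index still selects a weight matrix; only an index of 0 produces zeros.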
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
276 |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
277 class ScanDotDecGrad(Op): |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
278 """This Op computes the gradient wrt the weights for ScanDotDec""" |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
279 def __init__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
280 self.M=theano.Module() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
281 gout = T.dmatrix('gout') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
282 self.M.hidt = T.dmatrix('hid') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
283 oldval = T.dmatrix('oldval') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
284 resultin1 = oldval + T.dot(self.M.hidt,gout) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
285 result1 = T.dot(self.M.hidt,gout) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
286 weights = T.dmatrix('weights') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
287 weights2 = T.transpose(weights) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
288 resultin2 = oldval + T.dot(gout,weights2) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
289 result2 = T.dot(gout,weights2) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
290 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
291 self.M.dotin1 = theano.Method([gout,oldval],resultin1) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
292 self.M.dot1 = theano.Method([gout],result1) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
293 self.M.dotin2 = theano.Method([gout,weights,oldval],resultin2) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
294 self.M.dot2 = theano.Method([gout,weights],result2) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
295 self.m = self.M.make() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
296 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
297 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
298 def make_node(self, args, g_out): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
299 idx_list = Checkidx_list(args[0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
300 hidd = Checkhidd(args[2]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
301 weights_list = Checkweights_list(args[3:]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
302 return Apply(self, args + g_out, [T.dmatrix() for i in xrange(2,len(args))]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
303 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
304 def perform(self, node, args, z): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
305 idx_list = abs(args[0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
306 self.m.hidt = args[2].T |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
307 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
308 batchsize = (self.m.hidt.shape)[1] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
309 n_hid = self.m.hidt.shape[0] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
310 if max(idx_list) >= len(args)-4+1 : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
311 raise NotImplementedError('index exceeds weight list length',idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
312 if len(idx_list) != len(args[1]) : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
313 raise NotImplementedError('index list size differs from inputs list size',idx_list) |
768
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
314 for a in args[3:-1]: |
767
1e97e7c7f11f
very small opt.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
745
diff
changeset
|
315 if a.shape[0] != n_hid: |
768
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
316 raise NotImplementedError('different length of hidden in the weights list',a.shape) |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
317 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
318 zidx=numpy.zeros((len(idx_list)+1)) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
319 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
320 for i in range(len(idx_list)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
321 if idx_list[i] == 0: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
322 zidx[i+1] = (args[1][i].shape)[1] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
323 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
324 zidx[i+1] = (args[3+idx_list[i]-1].shape)[1] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
325 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
326 zidx=zidx.cumsum() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
327 hidcalc = False |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
328 zcalc = [False for i in range((len(args)-4))] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
329 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
330 for i in range(len(idx_list)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
331 if idx_list[i]>0: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
332 if zcalc[int(idx_list[i])-1]: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
333 z[int(idx_list[i])][0] = self.m.dotin1(args[-1][:,zidx[i]:zidx[i+1]],z[int(idx_list[i])][0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
334 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
335 z[int(idx_list[i])][0] = self.m.dot1(args[-1][:,zidx[i]:zidx[i+1]]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
336 zcalc[int(idx_list[i])-1] = True |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
337 if hidcalc: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
338 z[0][0] = self.m.dotin2(args[-1][:,zidx[i]:zidx[i+1]],args[3+int(idx_list[i]-1)],z[0][0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
339 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
340 z[0][0] = self.m.dot2(args[-1][:,zidx[i]:zidx[i+1]],args[3+int(idx_list[i]-1)]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
341 hidcalc = True |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
342 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
343 if not hidcalc: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
344 z[0][0] = numpy.zeros((self.m.hidt.shape[1],self.m.hidt.shape[0])) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
345 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
346 for i in range((len(args)-4)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
347 if not zcalc[i]: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
348 shp = args[3+i].shape |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
349 z[i+1][0] = numpy.zeros(shp) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
350 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
351 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
352 def __hash__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
353 return hash(ScanDotDecGrad)^87445 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
354 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
355 def __str__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
356 return "ScanDotDecGrad" |
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
357 |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
358 # DAA input noise------------------------------------ |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
359 class ScanNoise(Op): |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
360 """This Op takes an index list (as tensor.ivector), a list of matrices representing |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
361 the available inputs (as theano.generic), a probability of individual bit masking and |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
362 a probability of modality masking. It will return the inputs list with random zero entries |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
363 and the index list with some positive values changed to negative values (group masking)""" |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
364 def __init__(self, seed = 1): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
365 self.M=theano.Module() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
366 self.M.rand = T.RandomStreams(seed) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
367 self.seed = seed |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
368 mat = T.matrix('mat') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
369 noise_level_bit = T.dscalar('noise_level_bit') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
370 noise_level_group = T.dscalar('noise_level_group') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
371 self.M.out1 = self.M.rand.binomial(T.shape(mat), 1, 1 - noise_level_bit) * mat |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
372 self.M.out2 = self.M.rand.binomial((1,1), 1, 1 - noise_level_group) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
373 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
374 self.M.noisify_bit = theano.Method([mat,noise_level_bit],self.M.out1) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
375 self.M.noisify_group_bool = theano.Method([noise_level_group],self.M.out2) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
376 self.R = self.M.make() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
377 self.R.rand.initialize() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
378 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
379 def make_node(self, idx_list, inputs_list, noise_level_bit, noise_level_group): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
380 idx_list = Checkidx_list(idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
381 return Apply(self, [idx_list] + [inputs_list] + [noise_level_bit] + [noise_level_group],\ |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
382 [T.ivector(), theano.generic()]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
383 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
384 def perform(self, node, (idx_list,inputs_list,noise_level_bit,noise_level_group), (y,z)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
385 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
386 if len(idx_list) != len(inputs_list) : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
387 raise NotImplementedError('index list and inputs list differ in size',idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
388 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
389 y[0] = numpy.asarray([-i if (i>0 and not(self.R.noisify_group_bool(noise_level_group))) else i for i in idx_list]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
390 z[0] = [(self.R.noisify_bit(inputs_list[i],noise_level_bit) if y[0][i]>0 else numpy.zeros((inputs_list[i].shape)))\ |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
391 for i in range(len(inputs_list))] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
392 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
393 def grad(self,args,gz): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
394 return [None,None,None,None] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
395 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
396 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
397 def __hash__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
398 return hash(ScanNoise)^hash(self.seed)^hash(self.R.rand)^12254 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
399 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
400 def __str__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
401 return "ScanNoise" |
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
402 |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
403 scannoise=ScanNoise() |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
404 |
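The two masking steps described in the ScanNoise docstring can be illustrated with plain NumPy. This is only a sketch under stated assumptions, not the Theano Op itself: the function name `noisify` and the `rng` argument are hypothetical, and a NumPy random generator stands in for the `RandomStreams` binomial draws.

```python
import numpy as np


def noisify(idx_list, inputs_list, noise_level_bit, noise_level_group, rng):
    """NumPy-only illustration of the two masking steps (hypothetical helper)."""
    idx_list = np.asarray(idx_list)
    # Group masking: each positive index is negated with probability
    # noise_level_group; zero or negative indices are left unchanged.
    drop = rng.random(len(idx_list)) < noise_level_group
    noisy_idx = np.where((idx_list > 0) & drop, -idx_list, idx_list)

    noisy_inputs = []
    for idx, mat in zip(noisy_idx, inputs_list):
        if idx > 0:
            # Bit masking: each entry is kept with probability 1 - noise_level_bit.
            keep = rng.random(mat.shape) >= noise_level_bit
            noisy_inputs.append(mat * keep)
        else:
            # Missing or masked group: replaced by zeros of the same shape.
            noisy_inputs.append(np.zeros(mat.shape))
    return noisy_idx, noisy_inputs


rng = np.random.default_rng(0)
groups = [np.ones((2, 3)), np.ones((2, 4))]
# With noise_level_group = 1.0 every available group is masked.
idx, noisy = noisify([1, 2], groups, 0.0, 1.0, rng)
```

As in the Op, a group whose index is already zero or negative comes back as an all-zero matrix whatever the noise levels are.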
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
405 # Total input matrix construction------------------------------------ |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
406 class ScanInputs(Op): |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
407 """This Op takes an index list (as tensor.ivector) and a list of matrices representing |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
408 the available inputs (as theano.generic). It will construct the appropriate tensor.dmatrix |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
409 to compare to the reconstruction obtained with ScanDotDec""" |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
410 def make_node(self, idx_list, inputs_list): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
411 idx_list = Checkidx_list(idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
412 return Apply(self, [idx_list] + [inputs_list],[T.dmatrix()]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
413 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
414 def perform(self, node, (idx_list, inputs_list), (z,)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
415 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
416 if len(idx_list) != len(inputs_list): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
417 raise NotImplementedError('index list and inputs list differ in size',idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
418 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
419 for i in range(len(idx_list)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
420 if idx_list[i] == 0: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
421 inputs_list[i] = 0 * inputs_list[i] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
422 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
423 z[0] = numpy.concatenate(inputs_list,1) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
424 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
425 def grad(self,args,gz): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
426 return [None,None] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
427 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
428 def __hash__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
429 return hash(ScanInputs)^75902 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
430 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
431 def __str__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
432 return "ScanInputs" |
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
433 |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
434 scaninputs=ScanInputs() |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
435 |
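The construction performed by ScanInputs can be sketched in NumPy as follows. The name `build_target` is hypothetical, and unlike the Op above, this version builds a new list instead of overwriting entries of `inputs_list`; that is a design choice of the sketch, not something the source does.

```python
import numpy as np


def build_target(idx_list, inputs_list):
    """Concatenate the input groups column-wise, zeroing groups with index 0."""
    if len(idx_list) != len(inputs_list):
        raise ValueError("index list and inputs list differ in size")
    # Only groups whose index is exactly 0 are zeroed; groups with a negative
    # index keep their values here.
    blocks = [np.zeros_like(mat) if idx == 0 else mat
              for idx, mat in zip(idx_list, inputs_list)]
    return np.concatenate(blocks, axis=1)


groups = [np.ones((2, 2)), 2.0 * np.ones((2, 3))]
target = build_target([1, 0], groups)
```

The result has one row per example and as many columns as all groups combined, which is the shape the reconstruction is compared against.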
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
436 # Decoding bias vector construction------------------------------------ |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
437 class ScanBiasDec(Op): |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
438 """This Op takes an index list (as tensor.ivector), a list of matrices representing |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
439 the available inputs (as theano.generic) and the decoding bias tensor.dvector. |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
440 It will construct the appropriate bias tensor.dvector |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
441 to add to the reconstruction obtained with ScanDotDec""" |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
442 def make_node(self, idx_list, input_list, bias_list): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
443 idx_list = Checkidx_list(idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
444 bias_list = Checkbias_list(bias_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
445 return Apply(self, [idx_list] + [input_list] + bias_list, [T.dvector()]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
446 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
447 def perform(self, node, args, (z,)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
448 idx_list = abs(args[0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
449 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
450 if max(idx_list) >= (len(args)-2)+1 : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
451 raise NotImplementedError('index exceeds bias list length',idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
452 if len(idx_list) != len(args[1]) : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
453 raise NotImplementedError('index list and inputs list differ in size',idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
454 z[0] = [args[idx_list[i]+1] if idx_list[i] != 0 else numpy.zeros(args[1][i].shape[1]) \ |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
455 for i in range(len(idx_list))] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
456 z[0] = numpy.concatenate(z[0],0) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
457 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
458 def __hash__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
459 return hash(ScanBiasDec)^60056 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
460 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
461 def grad(self,args,gz): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
462 gradi = ScanBiasDecGrad()(args,gz) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
463 if type(gradi) != list: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
464 return [None, None] + [gradi] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
465 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
466 return [None, None] + gradi |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
467 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
468 def __str__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
469 return "ScanBiasDec" |
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
470 |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
471 scanbiasdec=ScanBiasDec() |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
472 |
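A NumPy sketch of how ScanBiasDec assembles the decoding bias may help. The name `build_bias` is hypothetical, and the biases are passed as one list rather than as trailing positional arguments, which is an assumption made for readability.

```python
import numpy as np


def build_bias(idx_list, inputs_list, bias_list):
    """Concatenate one bias vector per group; groups with index 0 get zeros."""
    idx_list = np.abs(np.asarray(idx_list))
    if len(idx_list) != len(inputs_list):
        raise ValueError("index list and inputs list differ in size")
    if len(idx_list) and idx_list.max() > len(bias_list):
        raise ValueError("index exceeds bias list length")
    # Index k (1-based, sign ignored) selects bias_list[k - 1]; index 0 yields
    # a zero vector as wide as the corresponding input group.
    parts = [bias_list[idx - 1] if idx != 0 else np.zeros(mat.shape[1])
             for idx, mat in zip(idx_list, inputs_list)]
    # The parts are 1-D vectors, so they are joined along axis 0.
    return np.concatenate(parts, axis=0)


biases = [np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0])]
inputs = [np.zeros((1, 2)), np.zeros((1, 3)), np.zeros((1, 2))]
bias = build_bias([1, 0, -1], inputs, biases)
```

Because the sign of each index is dropped first, a group masked by noise (negative index) still receives its bias, as in the Op above.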
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
473 class ScanBiasDecGrad(Op): |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
474 """This Op computes the gradient wrt the bias for ScanBiasDec""" |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
475 def make_node(self, args, g_out): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
476 idx_list = Checkidx_list(args[0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
477 bias_list = Checkbias_list(args[2:]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
478 return Apply(self, args + g_out, [T.dvector() for i in range(len(args)-2)]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
479 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
480 def perform(self, node, args, z): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
481 idx_list = abs(args[0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
482 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
483 if max(idx_list) >= (len(args)-3)+1 : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
484 raise NotImplementedError('index exceeds bias list length',idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
485 if len(idx_list) != len(args[1]) : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
486 raise NotImplementedError('index list and inputs list differ in size',idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
487 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
488 zidx=numpy.zeros((len(idx_list)+1), dtype=int) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
489 for i in range(len(idx_list)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
490 if idx_list[i] == 0: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
491 zidx[i+1] = (args[1][i].shape)[1] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
492 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
493 zidx[i+1] = (args[2+idx_list[i]-1].size) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
494 zidx=zidx.cumsum() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
495 zcalc = [False for i in range((len(args)-3))] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
496 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
497 for i in range(len(idx_list)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
498 if idx_list[i]>0: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
499 if zcalc[int(idx_list[i])-1]: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
500 z[int(idx_list[i])-1][0] += args[-1][zidx[i]:zidx[i+1]] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
501 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
502 z[int(idx_list[i])-1][0] = args[-1][zidx[i]:zidx[i+1]].copy() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
503 zcalc[int(idx_list[i])-1] = True |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
504 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
505 for i in range((len(args)-3)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
506 if not zcalc[i]: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
507 shp = args[2+i].size |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
508 z[i][0] = numpy.zeros(shp) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
509 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
510 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
511 def __hash__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
512 return hash(ScanBiasDecGrad)^41256 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
513 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
514 def __str__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
515 return "ScanBiasDecGrad" |
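The accumulation loop in `ScanBiasDecGrad.perform` (lines 495-508) can be sketched in plain NumPy as follows. This is only an illustrative sketch: the helper name `scatter_bias_grad` and its parameters are hypothetical, and it assumes that `zidx` holds cumulative slice offsets into the incoming gradient, which is not shown in this part of the file. The sketch copies the first slice before accumulating so that the incoming gradient is never modified in place.

```python
import numpy as np

def scatter_bias_grad(idx_list, offsets, grad, bias_sizes):
    """Sum slices of `grad` into one gradient vector per bias group.

    Assumption: grad[offsets[i]:offsets[i + 1]] belongs to position i of
    idx_list, and idx_list[i] > 0 names the (1-based) group that slice feeds.
    """
    out = [None] * len(bias_sizes)
    for i, g in enumerate(idx_list):
        g = int(g)
        if g > 0:
            chunk = grad[offsets[i]:offsets[i + 1]]
            if out[g - 1] is None:
                # Copy so later additions never write into `grad` itself.
                out[g - 1] = chunk.copy()
            else:
                out[g - 1] = out[g - 1] + chunk
    # Groups that never appear in idx_list get a zero gradient.
    return [np.zeros(n) if o is None else o for o, n in zip(out, bias_sizes)]

grad = np.arange(8.0)
bias_grads = scatter_bias_grad([1, 0, 2, 1], [0, 2, 4, 6, 8], grad, [2, 2, 3])
```

Here group 1 receives the sum of the first and last slices, group 2 receives the third slice, and the unused third group is filled with zeros.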
694
69947f4e9c0e
added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
686
diff
changeset
|
516 |
69947f4e9c0e
added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
686
diff
changeset
|
517 # Mask construction------------------------------------ |
69947f4e9c0e
added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
686
diff
changeset
|
518 class ScanMask(Op): |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
519 """This Op takes an index list (as tensor.ivector) and a list of weights. |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
520 It will construct a list of T.iscalar representing the Mask |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
521 to do the correct regularisation on the weights""" |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
522 def __init__(self,encbool=True): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
523 self.encbool = encbool |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
524 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
525 def make_node(self, idx_list, weights_list): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
526 idx_list = Checkidx_list(idx_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
527 weights_list = Checkweights_list(weights_list) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
528 return Apply(self, [idx_list] + weights_list, [T.iscalar() for i in range(len(weights_list))]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
529 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
530 def perform(self, node, args, z): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
531 if self.encbool: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
532 idx_list = args[0] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
533 dim = 1 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
534 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
535 idx_list = abs(args[0]) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
536 dim = 0 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
537 n_hid = args[1].shape[dim] |
701
113946723973
fixed bug of scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
700
diff
changeset
|
538 |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
539 if max(idx_list) >= (len(args)-1)+1 : |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
540 raise NotImplementedError('index superior to weights list length',idx_list) |
768
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
541 for a in args[1:]: |
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
542 if a.shape[dim] != n_hid: |
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
543 raise NotImplementedError('different length of hidden in the encoding weights list',a.shape) |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
544 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
545 for i in range(len(args[1:])): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
546 z[i][0] = numpy.asarray((idx_list == i+1).sum(),dtype='int32') |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
547 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
548 def __hash__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
549 return hash(ScanMask)^hash(self.encbool)^11447 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
550 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
551 def grad(self,args,gz): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
552 return [None] * len(args) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
553 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
554 def __str__(self): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
555 if self.encbool: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
556 string = "Enc" |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
557 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
558 string = "Dec" |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
559 return "ScanMask" + string |
694
69947f4e9c0e
added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
686
diff
changeset
|
560 |
69947f4e9c0e
added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
686
diff
changeset
|
561 scanmaskenc=ScanMask(True) |
711
0eae6d5315b5
Fixed minor typo in comment
Olivier Delalleau <delallea@iro>
parents:
701
diff
changeset
|
562 scanmaskdec=ScanMask(False) |
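The counting done by `ScanMask.perform` can be reproduced without Theano. The following is a minimal sketch, assuming a hypothetical helper `scan_mask_counts` that takes the number of weight matrices directly instead of the weight list; it is not the Op itself.

```python
import numpy as np

def scan_mask_counts(idx_list, n_groups, encbool=True):
    """For each group g in 1..n_groups, count the entries of idx_list equal to g."""
    idx = np.asarray(idx_list)
    if not encbool:
        # The decoding variant ignores the sign of each index.
        idx = np.abs(idx)
    if idx.max() > n_groups:
        raise ValueError("index exceeds the number of weight matrices")
    return [np.asarray((idx == g + 1).sum(), dtype="int32")
            for g in range(n_groups)]

enc_counts = scan_mask_counts([1, 0, 2, 1, -2], n_groups=2, encbool=True)
dec_counts = scan_mask_counts([1, 0, 2, 1, -2], n_groups=2, encbool=False)
```

With this input the encoding counts are 2 and 1, while the decoding counts are 2 and 2, because the entry `-2` only matches group 2 once its sign is dropped.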
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
563 |
717
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
564 # TODO The classes FillMissing and MaskSelect below should probably be moved |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
565 # to another (more appropriate) file. |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
566 class FillMissing(Op): |
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
567 """ |
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
568 Given an input, output two elements: |
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
569 - a copy of the input where missing values (NaN) are replaced by some |
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
570 other value (zero by default) |
720
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
571 - a mask of the same size and type as input, where each element is zero |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
572 iff the corresponding input is missing |
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
573 The 'fill_with' parameter may either be: |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
574 - a scalar: all missing values are replaced with this value |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
575 - a Numpy array: a missing value is replaced by the value in this array |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
576 at the same position (ignoring the first k dimensions if 'fill_with' |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
577 has k fewer dimensions than the input) |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
578 Currently, the gradient is computed as if the input value was really what |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
579 it was replaced with. It may be safer to replace the gradient w.r.t. |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
580 missing values with either zeros or missing values (?). |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
581 """ |
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
582 |
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
583 def __init__(self, fill_with=0): |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
584 super(FillMissing, self).__init__() |
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
585 self.fill_with = fill_with |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
586 self.fill_with_is_array = isinstance(self.fill_with, numpy.ndarray) |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
587 |
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
588 def __eq__(self, other): |
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
589 return (type(self) == type(other) and |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
590 self.fill_with_is_array == other.fill_with_is_array and |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
591 ((self.fill_with_is_array and |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
592 (self.fill_with == other.fill_with).all()) or |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
593 (not self.fill_with_is_array and self.fill_with == other.fill_with))) |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
594 |
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
595 def __hash__(self): |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
596 if self.fill_with_is_array: |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
597 fill_hash = self.fill_with.__hash__() |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
598 else: |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
599 fill_hash = hash(self.fill_with) |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
600 return hash(type(self))^hash(self.fill_with_is_array)^fill_hash |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
601 |
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
602 def make_node(self, input): |
720
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
603 return Apply(self, [input], [input.type(), input.type()]) |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
604 |
724
d42b4bcbb582
Replaced debug special code for missing values (-123456) by truly missing (NaN)
Olivier Delalleau <delallea@iro>
parents:
720
diff
changeset
|
605 def perform(self, node, (input, ), output_storage): |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
606 out = output_storage[0] |
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
607 out[0] = input.copy() |
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
608 out = out[0] |
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
609 mask = output_storage[1] |
720
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
610 mask[0] = numpy.ones(input.shape) |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
611 mask = mask[0] |
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
612 if self.fill_with_is_array: |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
613 ignore_k = len(out.shape) - len(self.fill_with.shape) |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
614 assert ignore_k >= 0 |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
615 for (idx, v) in numpy.ndenumerate(out): |
724
d42b4bcbb582
Replaced debug special code for missing values (-123456) by truly missing (NaN)
Olivier Delalleau <delallea@iro>
parents:
720
diff
changeset
|
616 if numpy.isnan(v): |
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
617 if self.fill_with_is_array: |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
618 out[idx] = self.fill_with[idx[ignore_k:]] |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
619 else: |
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
620 out[idx] = self.fill_with |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
621 mask[idx] = 0 |
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
622 |
720
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
623 def grad(self, inputs, (out_grad, mask_grad, )): |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
624 return [out_grad] |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
625 |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
626 fill_missing_with_zeros = FillMissing(0) |
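The element-by-element loop in `FillMissing.perform` can also be written with vectorized NumPy calls. The sketch below uses a hypothetical function `fill_missing` and relies on NumPy broadcasting, which aligns trailing dimensions in the same way as `fill_with[idx[ignore_k:]]`. Unlike the Op, broadcasting would also stretch size-1 dimensions of `fill_with`.

```python
import numpy as np

def fill_missing(x, fill_with=0):
    """Return (filled copy of x, mask) where mask is 0 at NaN positions, 1 elsewhere."""
    x = np.asarray(x, dtype=float)
    missing = np.isnan(x)
    mask = np.ones(x.shape, dtype=x.dtype)
    mask[missing] = 0
    # np.where broadcasts a scalar or a lower-dimensional array against x.
    out = np.where(missing, fill_with, x)
    return out, mask

x = np.array([[1.0, np.nan], [np.nan, 4.0]])
filled, mask = fill_missing(x, fill_with=np.array([10.0, 20.0]))
```

Each missing entry is replaced by the value of `fill_with` in the same column, and `x` itself is left unchanged.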
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
627 |
720
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
628 class MaskGradient(Op): |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
629 """ |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
630 Takes as input a tensor and a mask. Outputs the same tensor, but setting |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
631 to zero the gradient for all elements where the mask's value is zero. |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
632 """ |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
633 |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
634 def __eq__(self, other): |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
635 return type(self) == type(other) |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
636 |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
637 def __hash__(self): |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
638 return hash(type(self)) |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
639 |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
640 def make_node(self, input, mask): |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
641 return Apply(self, [input, mask], [input.type()]) |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
642 |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
643 def perform(self, node, (input, mask), (output, )): |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
644 output[0] = input.copy() |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
645 |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
646 def grad(self, (input, mask), (out_grad, )): |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
647 return [out_grad * T.neq(mask, 0), None] |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
648 |
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
649 mask_gradient = MaskGradient() |
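The backward rule of `MaskGradient`, `out_grad * T.neq(mask, 0)`, has a direct NumPy analogue. The function name `mask_gradient_backward` below is hypothetical and the sketch only illustrates the arithmetic, not Theano's symbolic gradient machinery.

```python
import numpy as np

def mask_gradient_backward(out_grad, mask):
    """Zero the gradient wherever the mask is exactly zero; pass it through elsewhere."""
    out_grad = np.asarray(out_grad, dtype=float)
    return out_grad * (np.asarray(mask) != 0)

masked_grad = mask_gradient_backward([0.5, 1.0, 2.0], [1, 0, 3])
```

Any nonzero mask value, not only 1, lets the gradient through unscaled.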
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
650 |
717
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
651 class MaskSelect(Op): |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
652 """ |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
653 Given an input x and a mask m (both vectors), outputs a vector that |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
654 contains all elements x[i] such that bool(m[i]) is True. |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
655 """ |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
656 |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
657 def __eq__(self, other): |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
658 return type(self) == type(other) |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
659 |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
660 def __hash__(self): |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
661 return hash(type(self)) |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
662 |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
663 def make_node(self, input, mask): |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
664 return Apply(self, [input, mask], [input.type()]) |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
665 |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
666 def perform(self, node, (input, mask), (output, )): |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
667 select = [] |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
668 for (i, m) in enumerate(mask): |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
669 if bool(m): |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
670 select.append(i) |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
671 output[0] = numpy.zeros(len(select), dtype = input.dtype) |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
672 out = output[0] |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
673 for (i, j) in enumerate(select): |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
674 out[i] = input[j] |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
675 |
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
676 mask_select = MaskSelect() |
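The two loops in `MaskSelect.perform` amount to boolean indexing. A minimal NumPy sketch, using a hypothetical function `mask_select_np` and assuming `x` and `mask` are vectors of the same length:

```python
import numpy as np

def mask_select_np(x, mask):
    """Keep the elements x[i] for which bool(mask[i]) is True."""
    x = np.asarray(x)
    keep = np.asarray(mask).astype(bool)
    return x[keep]

selected = mask_select_np([10, 20, 30, 40], [1, 0, 2, 0])
```

The result keeps the input dtype and the original order of the selected elements.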