annotate pylearn/sandbox/scan_inputs_groups.py @ 1530:08b3e827575a
Change class path as Theano want to hide the Deprecated Module interface.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Fri, 12 Jul 2013 13:46:45 -0400 |
parents | d15683416ebf |
children |
rev | line source |
---|---|
686 0457dfa6fcad | add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups | Foo Bar <barfoo@iro.umontreal.ca>
import numpy
import theano
from theano import tensor as T
from theano.gof import Op
from theano.gof import Apply
from theano import scalar as scal

# These Ops allow us to deal efficiently with static groups of possibly missing inputs in the dense DAA framework
# (for example, multimodal data where an entire modality is sometimes missing).
# The inputs are represented by an index list and a theano.generic variable (a list of matrices
711 0eae6d5315b5 | Fixed minor typo in comment | Olivier Delalleau <delallea@iro> | parents: 701
# (numpy arrays), where each element corresponds to an available modality and the index list indicates the weights
686 0457dfa6fcad
# associated with it).
# Example of index list: [1, 0, -3]
714 8d5d42274bd1 | improved readability DAA_inputs_groups and scan_inputs_groups | Xavier Glorot <glorotxa@iro.umontreal.ca> | parents: 711
# *the 1 says that the first element of the input list refers to the first element of the weights_list
# (auxiliary target as input)
# if indexlist[i]>0 it refers to weights_list[indexlist[i]-1]
# *the 0 means that the second element of the input list is neither encoded nor decoded (it is replaced by zeros);
# this is not efficient, so in this case it is better to give: [1,-3] and [inputslist[0],inputslist[2]]
# but it allows us to deal with empty lists: give indexlist = numpy.asarray([.0])
# and inputlist=numpy.zeros((batchsize,1))
# *when an index is negative, the input is not used for encoding but is still reconstructed
# (auxiliary target as output)
# if indexlist[i]<0 it refers to weights_list[-indexlist[i]-1]
686 0457dfa6fcad
#
# All examples in a batch must share the same configuration of available inputs.
#
# Dense DAA Example:----------------------------------------------------------------------------
#
#from theano.tensor.nnet import sigmoid
#
#nb_modality = 4
#wenc = [T.dmatrix('wenc%s'%i) for i in range(nb_modality)]
#wdec = [T.dmatrix('wdec%s'%i) for i in range(nb_modality)]
#benc = T.dvector('benc')
#bdec = [T.dvector('bdec%s'%i) for i in range(nb_modality)]
#vectin = T.ivector('vectin')
#inputpart = theano.generic('inputpart')
#noise_bit = T.dscalar('noise_bit')
#noise_group = T.dscalar('noise_group')
#
#[vectin2,inputpart2] = scannoise(vectin,inputpart,noise_bit,noise_group)
#hid = scandotenc(vectin2, inputpart2, wenc)
#acthid = sigmoid(hid + benc)
#dec = sigmoid(scanbiasdec(vectin2,inputpart2,bdec) + scandotdec(vectin2, inputpart2,acthid,wdec))
#cost = T.sum(T.sum(T.sqr( scaninput(vectin,inputpart) - dec ),1),0)
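The index-list convention in the comments above packs three cases into one integer per input. As a reading aid, here is a minimal sketch that spells out that reading, assuming positive entries mean "encode and reconstruct", zero means "ignored", and negative entries mean "reconstruct only". `describe_index_list` is a hypothetical helper, not part of this module or of Theano.

```python
def describe_index_list(idx_list):
    """Hypothetical helper (not part of scan_inputs_groups) that spells out
    the index-list convention described in the comments.

    Returns one (input_position, role, weights_position) tuple per entry:
      idx > 0  -> input is encoded and reconstructed with weights_list[idx - 1]
      idx == 0 -> input is ignored (treated as zeros)
      idx < 0  -> input is only reconstructed, with weights_list[-idx - 1]
    """
    roles = []
    for pos, idx in enumerate(idx_list):
        idx = int(idx)  # also accepts float entries such as numpy.asarray([.0])
        if idx > 0:
            roles.append((pos, "encode and reconstruct", idx - 1))
        elif idx < 0:
            roles.append((pos, "reconstruct only", -idx - 1))
        else:
            roles.append((pos, "ignored", None))
    return roles
```

For the example index list `[1, 0, -3]`, this yields input 0 encoded and reconstructed with weights 0, input 1 ignored, and input 2 reconstructed only, with weights 2.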

# Checking inputs in make_node methods----------------------
def Checkidx_list(idx_list):
714 8d5d42274bd1
    idx_list = T.as_tensor_variable(idx_list)
    nidx = idx_list.type.ndim
    if nidx != 1: raise TypeError('not vector', idx_list)
    return idx_list
686 0457dfa6fcad

def Checkhidd(hidd):
714 8d5d42274bd1
    hidd = T.as_tensor_variable(hidd)
    nhidd = hidd.type.ndim
    if nhidd not in (1,2): raise TypeError('not matrix or vector', hidd)
    return hidd
686 0457dfa6fcad

def Checkweights_list(weights_list):
714 8d5d42274bd1
    weights_list = map(T.as_tensor_variable, weights_list)
    for i in range(len(weights_list)):
        nweights = weights_list[i].type.ndim
        if nweights not in (1,2): raise TypeError('not matrix or vector', weights_list[i])
    return weights_list
686 0457dfa6fcad

def Checkbias_list(bias_list):
714 8d5d42274bd1
    bias_list = map(T.as_tensor_variable, bias_list)
    for i in range(len(bias_list)):
        nbias = bias_list[i].type.ndim
        if nbias != 1: raise TypeError('not vector', bias_list[i])
    return bias_list
686 0457dfa6fcad

817 db2c26a2c97c | new parameters and Op for DAA inputs groups | Xavier Glorot <glorotxa@iro.umontreal.ca> | parents: 789

# block grad Op------------------------------------
class BlockGrad(Op):
    """This Op blocks the gradient of a variable"""
    def make_node(self, x):
        x = T.as_tensor_variable(x)
        if x.ndim == 1:
            return Apply(self, [x], [T.dvector()])
        else:
            return Apply(self, [x], [T.dmatrix()])

    def perform(self, node , x ,(out,)):
        out[0] = x[0].copy()

    def grad(self, x, (gx,)):
        return [gx*0]

    def __hash__(self):
        return hash(BlockGrad)^77612

    def __str__(self):
        return "BlockGrad"

blockgrad=BlockGrad()

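The BlockGrad Op passes its input through unchanged (`perform` returns a copy) but returns a zeroed gradient (`grad` returns `gx*0`), which other frameworks call a stop-gradient. Below is a minimal NumPy sketch of that contract with a hand-written backward pass so the effect is visible. `ToyBlockGrad` and `loss_and_grad` are hypothetical names, not part of pylearn or Theano, and the loss `sum((x - 2x)**2)` is an arbitrary illustration of using a blocked term as a fixed target.

```python
import numpy as np


class ToyBlockGrad:
    """Hypothetical NumPy stand-in for BlockGrad (not a pylearn/Theano API)."""

    def forward(self, x):
        # Like BlockGrad.perform: the output is a copy of the input.
        return np.array(x, dtype=float, copy=True)

    def backward(self, gx):
        # Like BlockGrad.grad: the incoming gradient is multiplied by zero.
        return gx * 0


def loss_and_grad(x, block=True):
    """L = sum((x - g(2 * x)) ** 2), where g is ToyBlockGrad if block else identity.

    Returns (L, dL/dx) computed with a hand-written backward pass.
    """
    x = np.asarray(x, dtype=float)
    t = 2.0 * x
    target = ToyBlockGrad().forward(t) if block else t
    d = x - target
    loss = float(np.sum(d ** 2))

    g_d = 2.0 * d          # dL/dd
    g_x = g_d.copy()       # path through the first x term
    g_target = -g_d        # dL/dtarget
    g_t = ToyBlockGrad().backward(g_target) if block else g_target
    g_x += 2.0 * g_t       # path through t = 2 * x (zero when blocked)
    return loss, g_x
```

With `x = [1, 2]`, both variants give the same loss of 5. The gradient is `[-2, -4]` when the target is blocked, because the target is treated as a constant. It is `[2, 4]` when the target is not blocked, which is the true derivative of `sum(x**2)`.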
686 0457dfa6fcad
# Encoding scan dot product------------------------------------
class ScanDotEnc(Op):
714 8d5d42274bd1
    """This Op takes an index list (as tensor.ivector), a list of matrices representing
    the available inputs (as theano.generic), and all the encoding weights of the model (as tensor.dmatrix). It selects the
    weights corresponding to the inputs (according to the index list) and computes only the necessary dot products"""
    def __init__(self):
        # Create Theano methods that do the dot products with BLAS, or at least in C.
        self.M=theano.Module()
        inputs = T.dmatrix('input')
        weights = T.dmatrix('weights')
        self.M.hid = T.dmatrix('hid')
        self.M.resultin = self.M.hid + T.dot(inputs,weights)
        result = T.dot(inputs,weights)

        self.M.dotin = theano.Method([inputs,weights],None,{self.M.hid : self.M.resultin})
        self.M.dot = theano.Method([inputs,weights],result)
        self.m = self.M.make()

    def make_node(self, idx_list, inputs_list, weights_list):
        idx_list = Checkidx_list(idx_list)
        weights_list = Checkweights_list(weights_list)
        return Apply(self, [idx_list] + [inputs_list] + weights_list, [T.dmatrix()])

    def perform(self, node, args, (hid,)):
        idx_list = args[0]
        hidcalc = False

        batchsize = (args[1][0].shape)[0]
        n_hid = (args[2].shape)[1]
        if len(idx_list) != len(args[1]) :
            raise NotImplementedError('size of index different of inputs list size',idx_list)
        if max(idx_list) >= (len(args)-2)+1 :
            raise NotImplementedError('index superior to weight list length',idx_list)
768 bd95e8ea99d8 | small change to Checks of scan_inputs_groups ops | Xavier Glorot <glorotxa@iro.umontreal.ca> | parents: 767
        for a in args[1]:
            if (a.shape)[0] != batchsize:
                raise NotImplementedError('different batchsize in the inputs list',a.shape)
        for a in args[2:]:
            if (a.shape)[1] != n_hid:
                raise NotImplementedError('different length of hidden in the weights list',a.shape)
714 8d5d42274bd1

        for i in range(len(idx_list)):
            if idx_list[i]>0:
                if hidcalc:
                    self.m.dotin(args[1][i],args[2+int(idx_list[i]-1)])
                else:
                    self.m.hid = self.m.dot(args[1][i],args[2+int(idx_list[i]-1)])
                    hidcalc = True

        if not hidcalc:
            hid[0] = numpy.zeros([batchsize,n_hid])
        else:
            hid[0] = self.m.hid


    def grad(self, args, gz):
        gradi = ScanDotEncGrad()(args,gz)
        if type(gradi) != list:
            return [None, None] + [gradi]
        else:
            return [None, None] + gradi

    def __hash__(self):
        return hash(ScanDotEnc)^58994

    def __str__(self):
        return "ScanDotEnc"
686 0457dfa6fcad

scandotenc=ScanDotEnc()
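`ScanDotEnc.perform` adds up, for every positive entry of the index list, the product of the corresponding input matrix with its encoding weights, and returns zeros when no entry is positive. Below is a minimal NumPy sketch of that computation, together with the matching gradient with respect to the weights, which `ScanDotEnc.grad` delegates to `ScanDotEncGrad`. The function names are hypothetical. The sketch assumes 2-D weight matrices and replaces the Theano Module/Method machinery with plain `@` products, so it illustrates the math rather than the Op's compiled implementation.

The gradient part assumes the standard rule that each selected weight receives `X.T @ g_out`. The visible `ScanDotEncGrad` lines (transpose of `input1`, dot with `g_out`, added to an accumulator) are consistent with that rule, but the rest of that Op is outside this excerpt.

```python
import numpy as np


def scan_dot_enc(idx_list, inputs_list, weights_list):
    """NumPy sketch of ScanDotEnc.perform (hypothetical name, not a pylearn API).

    Sums inputs_list[i] @ weights_list[idx_list[i] - 1] over the entries with
    idx_list[i] > 0; entries <= 0 are skipped. Returns zeros if none is positive.
    """
    batchsize = inputs_list[0].shape[0]
    n_hid = weights_list[0].shape[1]
    # Same consistency checks as the perform method, raised as ValueError here.
    if len(idx_list) != len(inputs_list):
        raise ValueError("index list and inputs list have different lengths")
    if max(idx_list) > len(weights_list):
        raise ValueError("index larger than the weights list length")
    if any(x.shape[0] != batchsize for x in inputs_list):
        raise ValueError("different batch sizes in the inputs list")
    if any(w.shape[1] != n_hid for w in weights_list):
        raise ValueError("different hidden sizes in the weights list")

    hid = np.zeros((batchsize, n_hid))
    for idx, x in zip(idx_list, inputs_list):
        if idx > 0:
            hid += x @ weights_list[int(idx) - 1]
    return hid


def scan_dot_enc_weight_grads(idx_list, inputs_list, weights_list, g_out):
    """Gradient of sum(scan_dot_enc(...) * g_out) with respect to each weight matrix.

    Each selected weight gets inputs_list[i].T @ g_out (accumulated if the same
    weight is selected more than once); unselected weights get zeros.
    """
    grads = [np.zeros_like(w, dtype=float) for w in weights_list]
    for idx, x in zip(idx_list, inputs_list):
        if idx > 0:
            grads[int(idx) - 1] += x.T @ g_out
    return grads
```

Skipping non-positive entries instead of multiplying by zero matrices mirrors the efficiency argument in the module comments: only the dot products for available, encoded inputs are computed.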
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
167 |
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
168 class ScanDotEncGrad(Op): |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
169     """This Op computes the gradient wrt the weights for ScanDotEnc""" |
170     def __init__(self): |
171         #Create Theano methods to do the dot products with blas or at least in C. |
172         self.M=theano.Module() |
173         input1 = T.dmatrix('input1') |
174         self.M.g_out = T.dmatrix('g_out') |
175         result = T.dmatrix('result') |
176         input2=T.transpose(input1) |
177         self.M.resultin = result + T.dot(input2,self.M.g_out) |
178         self.M.result = T.dot(input2,self.M.g_out) |
179 |
180         self.M.dotin = theano.Method([input1,result],self.M.resultin) |
181         self.M.dot = theano.Method([input1],self.M.result) |
182         self.m = self.M.make() |
183 |
184     def make_node(self, args, g_out): |
185         idx_list = Checkidx_list(args[0]) |
186         weights_list = Checkweights_list(args[2:]) |
187         return Apply(self, args + g_out, [T.dmatrix() for i in xrange(2,len(args))]) |
188 |
189     def perform(self, node, args, z): |
190         idx_list = args[0] |
191         self.m.g_out = args[-1] |
192 |
193         batchsize = (args[1][0].shape)[0] |
194         n_hid = (args[2].shape)[1] |
195         if len(idx_list) != len(args[1]) : |
196             raise NotImplementedError('size of index different of inputs list size',idx_list) |
197         if max(idx_list) >= (len(args)-3)+1 : |
198             raise NotImplementedError('index superior to weight list length',idx_list) |
768
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
199         for a in args[1]: |
200             if (a.shape)[0] != batchsize: |
201                 raise NotImplementedError('different batchsize in the inputs list',a.shape) |
202         for a in args[2:-1]: |
203             if (a.shape)[1] != n_hid: |
204                 raise NotImplementedError('different length of hidden in the weights list',a.shape) |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
205 |
206         zcalc = [False for i in range(len(args)-3)] |
207 |
208         for i in range(len(idx_list)): |
209             if idx_list[i]>0: |
210                 if zcalc[int(idx_list[i]-1)]: |
211                     z[int(idx_list[i]-1)][0] = self.m.dotin(args[1][i],z[int(idx_list[i]-1)][0]) |
212                 else: |
213                     z[int(idx_list[i]-1)][0] = self.m.dot(args[1][i]) |
214                 zcalc[int(idx_list[i]-1)] = True |
215 |
216         for i in range(len(args)-3): |
217             if not zcalc[i]: |
218                 shp = args[2+i].shape |
219                 z[i][0] = numpy.zeros(shp) |
220 |
221     def __hash__(self): |
222         return hash(ScanDotEncGrad)^15684 |
223 |
224     def __str__(self): |
225         return "ScanDotEncGrad" |
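ScanDotEncGrad only ever adds inputs[i]^T · g_out into the gradient of the weight matrix that input i selects, and it gives weights selected by no input a zero gradient. The NumPy sketch below restates that accumulation outside Theano. It assumes, as this accumulation implies, that the ScanDotEnc forward pass sums inputs[i] · weights[idx_list[i]-1] over the inputs with a positive index. The function name scan_dot_enc_weight_grads is hypothetical and not part of pylearn.

```python
import numpy as np


def scan_dot_enc_weight_grads(idx_list, inputs, weights, g_out):
    """Hypothetical NumPy restatement of ScanDotEncGrad.perform.

    Assumes the encoder pre-activation is the sum of
    inputs[i] @ weights[idx_list[i] - 1] over the slots with idx_list[i] > 0,
    and that g_out is the upstream gradient of that sum.
    """
    # Start every weight gradient at zero; weights selected by no input stay zero.
    grads = [np.zeros(w.shape) for w in weights]
    for i, idx in enumerate(idx_list):
        if idx > 0:  # slot i is present and uses weights[idx - 1]
            grads[int(idx) - 1] += inputs[i].T @ g_out
    return grads
```

Starting from zeros replaces the zcalc bookkeeping that the Op uses to choose between its dot and dotin methods. The sums are the same either way.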
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
226 |
227 # Decoding scan dot product------------------------------------ |
228 class ScanDotDec(Op): |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
229     """This Op takes an index list (as tensor.ivector), a list of matrices representing |
230     the available inputs (as theano.generic), the hidden layer of the DAA (theano.dmatrix) |
231     and all the decoding weights tensor.dmatrix of the model. It will select the |
232     weights corresponding to the available inputs (according to index list) and compute |
233     only the necessary dot products. The outputs will be concatenated and will represent |
234     the reconstruction of the different modality in the same order than the index list""" |
235     def __init__(self): |
236         #Create Theano methods to do the dot products with blas or at least in C. |
237         self.M=theano.Module() |
238         weights = T.dmatrix('weights') |
239         self.M.hid = T.dmatrix('hid') |
240         oldval = T.dmatrix('oldval') |
241         resultin = oldval + T.dot(self.M.hid,weights) |
242         result = T.dot(self.M.hid,weights) |
243 |
244         self.M.dotin = theano.Method([weights,oldval],resultin) |
245         self.M.dot = theano.Method([weights],result) |
246         self.m = self.M.make() |
247 |
248     def make_node(self, idx_list, input_list, hidd, weights_list): |
249         idx_list = Checkidx_list(idx_list) |
250         hidd = Checkhidd(hidd) |
251         weights_list = Checkweights_list(weights_list) |
252         return Apply(self, [idx_list] + [input_list] +[hidd] + weights_list,[T.dmatrix()]) |
253 |
254     def perform(self, node, args, (z,)): |
255 |
256         idx_list = abs(args[0]) |
257         self.m.hid = args[2] |
258 |
259         batchsize = (self.m.hid.shape)[0] |
260         n_hid = self.m.hid.shape[1] |
261         if max(idx_list) >= len(args)-3+1 : |
262             raise NotImplementedError('index superior to weight list length',idx_list) |
263         if len(idx_list) != len(args[1]) : |
264             raise NotImplementedError('size of index different of inputs list size',idx_list) |
768
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
265         for a in args[3:]: |
266             if (a.shape)[0] != n_hid: |
267                 raise NotImplementedError('different length of hidden in the weights list',a.shape) |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
268 |
269         zcalc = [False for i in idx_list] |
270         z[0] = [None for i in idx_list] |
271 |
272         for i in range(len(idx_list)): |
273             if idx_list[i]>0: |
274                 if zcalc[i]: |
275                     z[0][i] = self.m.dotin(args[3+int(idx_list[i]-1)],z[0][i]) |
276                 else: |
277                     z[0][i] = self.m.dot(args[3+int(idx_list[i]-1)]) |
278                 zcalc[i] = True |
279 |
280         for i in range(len(idx_list)): |
281             if not zcalc[i]: |
282                 shp = args[1][i].shape |
283                 z[0][i] = numpy.zeros((batchsize,shp[1])) |
284 |
285         z[0] = numpy.concatenate(z[0],1) |
286 |
287     def grad(self, args, gz): |
288         gradi = ScanDotDecGrad()(args,gz) |
289         if type(gradi) != list: |
290             return [None, None] + [gradi] |
291         else: |
292             return [None, None] + gradi |
293 |
294     def __hash__(self): |
295         return hash(ScanDotDec)^73568 |
296 |
297     def __str__(self): |
298         return "ScanDotDec" |
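The selection logic in ScanDotDec.perform can be restated in plain NumPy. Each slot with a nonzero index (after abs) is decoded as hid · weights[|idx|-1], every other slot becomes a zero block, and the blocks are concatenated column-wise in index-list order. This is a hypothetical helper (scan_dot_dec is not a pylearn function). It sizes a missing slot's zero block from that slot's own input, inputs[i], which is how ScanDotDecGrad reads block widths into zidx.

```python
import numpy as np


def scan_dot_dec(idx_list, inputs, hid, weights):
    """Hypothetical NumPy restatement of ScanDotDec.perform."""
    idx_list = np.abs(np.asarray(idx_list))  # perform() also applies abs()
    blocks = []
    for i, idx in enumerate(idx_list):
        if idx > 0:
            # Reconstruct slot i with the decoding matrix it selects.
            blocks.append(hid @ weights[int(idx) - 1])
        else:
            # Missing slot: zero block with one row per example, as wide as inputs[i].
            blocks.append(np.zeros((hid.shape[0], inputs[i].shape[1])))
    # Concatenate the per-slot reconstructions along the feature axis.
    return np.concatenate(blocks, axis=1)
```

Only the selected products are computed. That matches the docstring's point that the Op avoids dot products for unavailable inputs.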
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
299 |
300 scandotdec=ScanDotDec() |
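ScanDotDec.grad delegates to ScanDotDecGrad, whose compiled methods compute hid^T · gout (dot1) and gout · weights^T (dot2). Assume the forward map described in the ScanDotDec docstring. Let G_i be slot i's column block of the upstream gradient. The chain rule then gives two results. dL/dhid is the sum of G_i · W_k(i)^T over the present slots. dL/dW_k is the sum of hid^T · G_i over the slots that select W_k. The minimal NumPy sketch below splits the upstream gradient into these blocks using cumulative slot widths (ScanDotDecGrad.perform records the widths in zidx). The function name scan_dot_dec_grads is hypothetical.

```python
import numpy as np


def scan_dot_dec_grads(idx_list, inputs, hid, weights, g_out):
    """Hypothetical sketch of the gradients that ScanDotDecGrad returns.

    g_out is the upstream gradient of the concatenated reconstruction.
    Returns (gradient w.r.t. hid, list of gradients w.r.t. each weight matrix).
    """
    idx_list = np.abs(np.asarray(idx_list))
    # Width of each slot's block: the selected weight's columns, or the input's width.
    widths = [weights[int(k) - 1].shape[1] if k > 0 else inputs[i].shape[1]
              for i, k in enumerate(idx_list)]
    offsets = np.concatenate(([0], np.cumsum(widths)))  # column start of each slot

    g_hid = np.zeros(hid.shape)
    g_w = [np.zeros(w.shape) for w in weights]  # unused weights keep a zero gradient
    for i, k in enumerate(idx_list):
        if k > 0:  # missing slots were constant zeros, so they contribute nothing
            g_i = g_out[:, offsets[i]:offsets[i + 1]]
            w = weights[int(k) - 1]
            g_hid += g_i @ w.T                # same product as dot2: gout . weights^T
            g_w[int(k) - 1] += hid.T @ g_i    # same product as dot1: hid^T . gout
    return g_hid, g_w
```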
301 |
302 class ScanDotDecGrad(Op): |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
303     """This Op computes the gradient wrt the weights for ScanDotDec""" |
304     def __init__(self): |
305         self.M=theano.Module() |
306         gout = T.dmatrix('gout') |
307         self.M.hidt = T.dmatrix('hid') |
308         oldval = T.dmatrix('oldval') |
309         resultin1 = oldval + T.dot(self.M.hidt,gout) |
310         result1 = T.dot(self.M.hidt,gout) |
311         weights = T.dmatrix('weights') |
312         weights2 = T.transpose(weights) |
313         resultin2 = oldval + T.dot(gout,weights2) |
314         result2 = T.dot(gout,weights2) |
315 |
316         self.M.dotin1 = theano.Method([gout,oldval],resultin1) |
317         self.M.dot1 = theano.Method([gout],result1) |
318         self.M.dotin2 = theano.Method([gout,weights,oldval],resultin2) |
319         self.M.dot2 = theano.Method([gout,weights],result2) |
320         self.m = self.M.make() |
321 |
322 |
323     def make_node(self, args, g_out): |
324         idx_list = Checkidx_list(args[0]) |
325         hidd = Checkhidd(args[2]) |
326         weights_list = Checkweights_list(args[3:]) |
327         return Apply(self, args + g_out, [T.dmatrix() for i in xrange(2,len(args))]) |
328 |
329     def perform(self, node, args, z): |
330         idx_list = abs(args[0]) |
331         self.m.hidt = args[2].T |
332 |
333         batchsize = (self.m.hidt.shape)[1] |
334         n_hid = self.m.hidt.shape[0] |
335         if max(idx_list) >= len(args)-4+1 : |
336             raise NotImplementedError('index superior to weight list length',idx_list) |
337         if len(idx_list) != len(args[1]) : |
338             raise NotImplementedError('size of index different of inputs list size',idx_list) |
768
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
339         for a in args[3:-1]: |
767
1e97e7c7f11f
very small opt.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
745
diff
changeset
|
340             if a.shape[0] != n_hid:
768
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
341                 raise NotImplementedError('different length of hidden in the weights list',a.shape)
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
342
343         zidx = numpy.zeros(len(idx_list)+1, dtype=int)  # integer offsets: float slice bounds are rejected by recent NumPy
344
345         for i in range(len(idx_list)):
346             if idx_list[i] == 0:
347                 zidx[i+1] = (args[1][i].shape)[1]
348             else:
349                 zidx[i+1] = (args[3+idx_list[i]-1].shape)[1]
350
351         zidx=zidx.cumsum()
352         hidcalc = False
353         zcalc = [False for i in range((len(args)-4))]
354
355         for i in range(len(idx_list)):
356             if idx_list[i]>0:
357                 if zcalc[int(idx_list[i])-1]:
358                     z[int(idx_list[i])][0] = self.m.dotin1(args[-1][:,zidx[i]:zidx[i+1]],z[int(idx_list[i])][0])
359                 else:
360                     z[int(idx_list[i])][0] = self.m.dot1(args[-1][:,zidx[i]:zidx[i+1]])
361                     zcalc[int(idx_list[i])-1] = True
362                 if hidcalc:
363                     z[0][0] = self.m.dotin2(args[-1][:,zidx[i]:zidx[i+1]],args[3+int(idx_list[i]-1)],z[0][0])
364                 else:
365                     z[0][0] = self.m.dot2(args[-1][:,zidx[i]:zidx[i+1]],args[3+int(idx_list[i]-1)])
366                     hidcalc = True
367
368         if not hidcalc:
369             z[0][0] = numpy.zeros((self.m.hidt.shape[1],self.m.hidt.shape[0]))
370
371         for i in range((len(args)-4)):
372             if not zcalc[i]:
373                 shp = args[3+i].shape
374                 z[i+1][0] = numpy.zeros(shp)
375
376
377     def __hash__(self):
378         return hash(ScanDotDecGrad)^87445
379
380     def __str__(self):
381         return "ScanDotDecGrad"
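The `zidx` bookkeeping in `ScanDotDecGrad.perform` maps each input group to a column range of the concatenated reconstruction gradient `args[-1]`: an absent group (index 0) contributes the width of its input matrix, any other group contributes the column count of its decoding weight matrix, and a cumulative sum turns those widths into offsets. The plain-NumPy sketch below reproduces only that offset computation outside Theano. The helper name `group_column_offsets` and its argument names are hypothetical, and it assumes weight matrices of shape `(n_hid, group_width)`, as the `a.shape[0] != n_hid` check suggests.

```python
import numpy as np


def group_column_offsets(idx_list, inputs_list, weights_list):
    """Hypothetical helper mirroring the zidx computation of ScanDotDecGrad.

    idx_list: group indices; 0 means the group is absent, k != 0 selects
              weights_list[abs(k) - 1] (a negative k marks a noise-masked group)
    inputs_list: one (batch, width) input matrix per group
    weights_list: decoding weight matrices of shape (n_hid, width)
    """
    idx_list = np.abs(np.asarray(idx_list))
    widths = np.zeros(len(idx_list) + 1, dtype=int)
    for i, k in enumerate(idx_list):
        if k == 0:
            widths[i + 1] = inputs_list[i].shape[1]       # absent: input width
        else:
            widths[i + 1] = weights_list[k - 1].shape[1]  # present: weight width
    # Columns offsets[i]:offsets[i + 1] belong to group i.
    return widths.cumsum()


inputs = [np.ones((2, 3)), np.ones((2, 2)), np.ones((2, 5))]
weights = [np.ones((4, 3)), np.ones((4, 5))]  # n_hid = 4
offsets = group_column_offsets([1, 0, -2], inputs, weights)
grad_recon = np.arange(2 * 10).reshape(2, 10)  # stand-in for args[-1]
blocks = [grad_recon[:, offsets[i]:offsets[i + 1]] for i in range(3)]
```

Keeping the offsets in an integer array lets them be used directly as slice bounds.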
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
382
383 # DAA input noise------------------------------------
384 class ScanNoise(Op):
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
385     """This Op takes an index list (as tensor.ivector), a list of matrices representing
386     the available inputs (as theano.generic), a probability of individual bit masking and
387     a probability of modality masking. It returns the inputs list with random zero entries
388     and the index list with some positive values changed to negative values (group masking)"""
389     def __init__(self, seed = 1):
390         self.M=theano.Module()
1530
08b3e827575a
Change class path as Theano want to hide the Deprecated Module interface.
Frederic Bastien <nouiz@nouiz.org>
parents:
852
diff
changeset
|
391         self.M.rand = T.randomstreams.RandomStreams(seed)
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
392         self.seed = seed
393         mat = T.matrix('mat')
394         noise_level_bit = T.dscalar('noise_level_bit')
395         noise_level_group = T.dscalar('noise_level_group')
396         self.M.out1 = self.M.rand.binomial(T.shape(mat), 1, 1 - noise_level_bit) * mat
397         self.M.out2 = self.M.rand.binomial((1,1), 1, 1 - noise_level_group)
398
399         self.M.noisify_bit = theano.Method([mat,noise_level_bit],self.M.out1)
400         self.M.noisify_group_bool = theano.Method([noise_level_group],self.M.out2)
401         self.R = self.M.make()
402         self.R.rand.initialize()
403
404     def make_node(self, idx_list, inputs_list, noise_level_bit, noise_level_group):
405         idx_list = Checkidx_list(idx_list)
406         return Apply(self, [idx_list] + [inputs_list] + [noise_level_bit] + [noise_level_group],\
407                 [T.ivector(), theano.generic()])
408
409     def perform(self, node, (idx_list,inputs_list,noise_level_bit,noise_level_group), (y,z)):
410
411         if len(idx_list) != len(inputs_list) :
412             raise NotImplementedError('size of index different of inputs list size',idx_list)
413
414         y[0] = numpy.asarray([-i if (i>0 and not(self.R.noisify_group_bool(noise_level_group))) else i for i in idx_list])
415         z[0] = [(self.R.noisify_bit(inputs_list[i],noise_level_bit) if y[0][i]>0 else numpy.zeros((inputs_list[i].shape)))\
416                 for i in range(len(inputs_list))]
417
418     def grad(self,args,gz):
419         return [None,None,None,None]
420
421
422     def __hash__(self):
423         return hash(ScanNoise)^hash(self.seed)^hash(self.R.rand)^12254
424
425     def __str__(self):
426         return "ScanNoise"
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
427
428 scannoise=ScanNoise()
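`ScanNoise.perform` corrupts the inputs at two levels: each present group (positive index) is dropped with probability `noise_level_group`, which flips its index to negative, and every entry of a surviving group is zeroed independently with probability `noise_level_bit`; groups whose resulting index is not positive are replaced by zero matrices. The sketch below shows that behavior in plain NumPy, assuming NumPy's `Generator` API in place of the deprecated `theano.Module`/`RandomStreams` machinery used above. The function name `scan_noise` is hypothetical, and its random draws will not match the Op's stream.

```python
import numpy as np


def scan_noise(idx_list, inputs_list, noise_level_bit, noise_level_group, rng):
    """Sketch of the ScanNoise corruption (assumed semantics, NumPy RNG)."""
    new_idx = []
    for i in idx_list:
        # A present group is dropped (index made negative) with
        # probability noise_level_group; the draw happens only for i > 0.
        if i > 0 and rng.random() < noise_level_group:
            new_idx.append(-i)
        else:
            new_idx.append(i)
    new_idx = np.asarray(new_idx)

    noisy = []
    for k, mat in zip(new_idx, inputs_list):
        if k > 0:
            # Keep each entry with probability 1 - noise_level_bit.
            keep = rng.binomial(1, 1.0 - noise_level_bit, size=mat.shape)
            noisy.append(keep * mat)
        else:
            # Absent or masked group: replaced by zeros.
            noisy.append(np.zeros(mat.shape))
    return new_idx, noisy


rng = np.random.default_rng(0)
groups = [np.ones((2, 3)), np.ones((2, 2)), np.ones((2, 4))]
idx, noisy = scan_noise([1, 0, 2], groups, noise_level_bit=0.25,
                        noise_level_group=0.5, rng=rng)
```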
429
430 # Total input matrix construction------------------------------------
431 class ScanInputs(Op):
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
432     """This Op takes an index list (as tensor.ivector) and a list of matrices representing
433     the available inputs (as theano.generic). It will construct the appropriate tensor.dmatrix
434     to compare to the reconstruction obtained with ScanDotDec"""
435     def make_node(self, idx_list, inputs_list):
436         idx_list = Checkidx_list(idx_list)
437         return Apply(self, [idx_list] + [inputs_list],[T.dmatrix()])
438
439     def perform(self, node, (idx_list, inputs_list), (z,)):
440
441         if len(idx_list) != len(inputs_list):
442             raise NotImplementedError('size of index different of inputs list size',idx_list)
443
444         for i in range(len(idx_list)):
445             if idx_list[i] == 0:
446                 inputs_list[i] = 0 * inputs_list[i]
447
448         z[0] = numpy.concatenate(inputs_list,1)
449
450     def grad(self,args,gz):
451         return [None,None]
452
453     def __hash__(self):
454         return hash(ScanInputs)^75902
455
456     def __str__(self):
457         return "ScanInputs"
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
458
459 scaninputs=ScanInputs()
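`ScanInputs.perform` builds the target matrix by zeroing the blocks of absent groups (index exactly 0) and concatenating all blocks column-wise; blocks with a negative index are not zeroed. Below is a NumPy sketch of the same computation under the hypothetical name `scan_inputs`. Unlike the Op, it builds a new list instead of overwriting entries of `inputs_list`, so the caller's list is left unchanged.

```python
import numpy as np


def scan_inputs(idx_list, inputs_list):
    """Sketch of ScanInputs.perform: zero absent groups, then concatenate."""
    if len(idx_list) != len(inputs_list):
        raise ValueError("size of index list differs from inputs list size")
    # Only index 0 (absent) is zeroed; negative (masked) groups are kept.
    blocks = [np.zeros_like(m) if k == 0 else m
              for k, m in zip(idx_list, inputs_list)]
    return np.concatenate(blocks, axis=1)  # shape (batch, total width)


target = scan_inputs([1, 0, -2], [np.full((2, 2), 1.0),
                                  np.full((2, 1), 5.0),
                                  np.full((2, 3), 7.0)])
```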
460
461 # Decoding bias vector construction------------------------------------
462 class ScanBiasDec(Op):
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
463     """This Op takes an index list (as tensor.ivector), a list of matrices representing
464     the available inputs (as theano.generic) and the decoding bias tensor.dvector.
465     It will construct the appropriate bias tensor.dvector
466     to add to the reconstruction obtained with ScanDotDec"""
467     def make_node(self, idx_list, input_list, bias_list):
468         idx_list = Checkidx_list(idx_list)
469         bias_list = Checkbias_list(bias_list)
470         return Apply(self, [idx_list] + [input_list] + bias_list, [T.dvector()])
471
472     def perform(self, node, args, (z,)):
473         idx_list = abs(args[0])
474
475         if max(idx_list) >= (len(args)-2)+1 :
476             raise NotImplementedError('index superior to bias list length',idx_list)
477         if len(idx_list) != len(args[1]) :
478             raise NotImplementedError('size of index different of inputs list size',idx_list)
479         z[0] = [args[idx_list[i]+1] if idx_list[i] != 0 else numpy.zeros(args[1][i].shape[1]) \
480                 for i in range(len(idx_list))]
481         z[0] = numpy.concatenate(z[0],0)  # 1-D bias vectors: axis 1 is invalid for 1-D arrays in current NumPy
482
483     def __hash__(self):
484         return hash(ScanBiasDec)^60056
485
486     def grad(self,args,gz):
487         gradi = ScanBiasDecGrad()(args,gz)
488         if type(gradi) != list:
489             return [None, None] + [gradi]
490         else:
491             return [None, None] + gradi
492
493     def __str__(self):
494         return "ScanBiasDec"
686
0457dfa6fcad
add direclty the file scan_inputs_groups.py to sandbox and remove the directory input_groups
Foo Bar <barfoo@iro.umontreal.ca>
parents:
diff
changeset
|
495
496 scanbiasdec=ScanBiasDec()
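`ScanBiasDec.perform` assembles the decoding bias in the same group order: for each group with a non-zero index `k` it takes `bias_list[abs(k) - 1]` (the Op reads it as `args[abs(k) + 1]`, since `args[0]` and `args[1]` hold the index and input lists), and for an absent group it uses a zero vector as wide as that group's input. A plain-NumPy sketch under those assumptions follows; `scan_bias_dec` is a hypothetical name. Because every piece is 1-D, the pieces are joined along axis 0.

```python
import numpy as np


def scan_bias_dec(idx_list, inputs_list, bias_list):
    """Sketch of ScanBiasDec.perform: one bias vector per group, concatenated."""
    idx_list = np.abs(np.asarray(idx_list))  # masked groups keep their bias
    parts = [bias_list[k - 1] if k != 0 else np.zeros(inputs_list[i].shape[1])
             for i, k in enumerate(idx_list)]
    return np.concatenate(parts)  # 1-D pieces, default axis 0


b = scan_bias_dec([1, 0, -2],
                  [np.ones((4, 2)), np.ones((4, 1)), np.ones((4, 3))],
                  [np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0])])
```

Assuming each bias has the same length as its group's input width, the result is as long as a row of the `ScanInputs` output.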
497
498 class ScanBiasDecGrad(Op):
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
499     """This Op computes the gradient wrt the bias for ScanBiasDec"""
500     def make_node(self, args, g_out):
501         idx_list = Checkidx_list(args[0])
502         bias_list = Checkbias_list(args[2:])
503         return Apply(self, args + g_out, [T.dvector() for i in range(len(args)-2)])
504
505     def perform(self, node, args, z):
506         idx_list = abs(args[0])
507
508         if max(idx_list) >= (len(args)-3)+1 :
509             raise NotImplementedError('index superior to bias list length',idx_list)
510         if len(idx_list) != len(args[1]) :
511             raise NotImplementedError('size of index different of inputs list size',idx_list)
512
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
513 zidx=numpy.zeros((len(idx_list)+1)) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
514 for i in range(len(idx_list)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
515 if idx_list[i] == 0: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
516 zidx[i+1] = (args[1][i].shape)[1] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
517 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
518 zidx[i+1] = (args[2+idx_list[i]-1].size) |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
519 zidx=zidx.cumsum() |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
520 zcalc = [False for i in range((len(args)-3))] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
521 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
522 for i in range(len(idx_list)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
523 if idx_list[i]>0: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
524 if zcalc[int(idx_list[i])-1]: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
525 z[int(idx_list[i])-1][0] += args[-1][zidx[i]:zidx[i+1]] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
526 else: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
527 z[int(idx_list[i])-1][0] = args[-1][zidx[i]:zidx[i+1]] |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
528 zcalc[int(idx_list[i])-1] = True |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
529 |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
530 for i in range((len(args)-3)): |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
531 if not zcalc[i]: |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
532 shp = args[2+i].size |
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
533 z[i][0] = numpy.zeros(shp) |


    def __hash__(self):
        return hash(ScanBiasDecGrad)^41256

    def __str__(self):
        return "ScanBiasDecGrad"
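ScanBiasDecGrad.perform splits one concatenated gradient vector (args[-1]) into one gradient per bias: each segment's width comes from its bias (or from the input width when the index is 0), the cumulative sum of the widths gives slice offsets, and segments that share a bias are summed. A minimal NumPy sketch of that bookkeeping outside Theano; the function `split_bias_grad` and its argument layout are hypothetical, chosen only for illustration:

```python
import numpy


def split_bias_grad(idx_list, input_widths, bias_sizes, g_concat):
    """Split a concatenated gradient into one gradient per bias vector.

    As in perform, absolute values of the indices are used: index 0 means
    segment i has no bias (its width is input_widths[i]), and index k > 0
    means segment i uses bias k-1.
    """
    idx_list = numpy.abs(numpy.asarray(idx_list))
    # Segment widths, then integer offsets into g_concat.
    widths = [input_widths[i] if k == 0 else bias_sizes[k - 1]
              for i, k in enumerate(idx_list)]
    offsets = numpy.concatenate([[0], numpy.cumsum(widths)]).astype(int)

    # Biases selected by no segment keep a zero gradient.
    grads = [numpy.zeros(size) for size in bias_sizes]
    for i, k in enumerate(idx_list):
        if k > 0:
            # Biases shared by several segments accumulate their gradients.
            grads[k - 1] += g_concat[offsets[i]:offsets[i + 1]]
    return grads


# Segments 0 and 2 share bias 0, segment 1 has no bias, bias 1 is unused.
print(split_bias_grad([1, 0, 1], [2, 3, 2], [2, 4], numpy.arange(7.0)))
```

For this call the first bias receives [5., 7.] (segments [0, 1] and [5, 6] summed), and the unused second bias gets zeros, matching the zero-fill at the end of perform.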
694
69947f4e9c0e
added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
686
diff
changeset
|

# Mask construction------------------------------------
class ScanMask(Op):
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
    """This Op takes an index list (as tensor.ivector) and a list of weights.
    It constructs a list of T.iscalar, one per weight matrix, each counting
    how many times the index list selects that matrix; these counts form the
    mask used to apply the correct regularisation to the weights."""
    def __init__(self,encbool=True):
        self.encbool = encbool

    def make_node(self, idx_list, weights_list):
        idx_list = Checkidx_list(idx_list)
        weights_list = Checkweights_list(weights_list)
        return Apply(self, [idx_list] + weights_list, [T.iscalar() for i in range(len(weights_list))])

    def perform(self, node, args, z):
        if self.encbool:
            idx_list = args[0]
            dim = 1
        else:
            idx_list = abs(args[0])
            dim = 0
        n_hid = args[1].shape[dim]
701
113946723973
fixed bug of scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
700
diff
changeset
|
563 |
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|
        if max(idx_list) >= (len(args)-1)+1 :
            raise NotImplementedError('index greater than weights list length',idx_list)
768
bd95e8ea99d8
small change to Checks of scan_inputs_groups ops
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
767
diff
changeset
|
        for a in args[1:]:
            if a.shape[dim] != n_hid:
                raise NotImplementedError('weights in the list have different numbers of hidden units',a.shape)
714
8d5d42274bd1
improved readability DAA_inputs_groups and scan_inputs_groups
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
711
diff
changeset
|

        for i in range(len(args[1:])):
            z[i][0] = numpy.asarray((idx_list == i+1).sum(),dtype='int32')

    def __hash__(self):
        return hash(ScanMask)^hash(self.encbool)^11447

    def grad(self,args,gz):
        return [None] * len(args)

    def __str__(self):
        if self.encbool:
            string = "Enc"
        else:
            string = "Dec"
        return "ScanMask" + string
694
69947f4e9c0e
added a Mask creation Op and fixed some bugs
Xavier Glorot <glorotxa@iro.umontreal.ca>
parents:
686
diff
changeset
|

scanmaskenc=ScanMask(True)
711
0eae6d5315b5
Fixed minor typo in comment
Olivier Delalleau <delallea@iro>
parents:
701
diff
changeset
|
scanmaskdec=ScanMask(False)
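ScanMask.perform reduces to counting, for each weight matrix i, the entries of the index list equal to i+1 (after taking absolute values in the decoding case). A NumPy sketch of that count outside Theano; the function name `mask_counts` is hypothetical, and the shape checks done by the Op are omitted:

```python
import numpy


def mask_counts(idx_list, n_weights, encoding=True):
    """Return, for each weight matrix, how many index entries select it.

    Mirrors ScanMask.perform: entry i counts idx == i + 1. When decoding,
    absolute values of the indices are used first.
    """
    idx = numpy.asarray(idx_list)
    if not encoding:
        idx = numpy.abs(idx)
    # One 0-d int32 array per weight matrix, like the Op's iscalar outputs.
    return [numpy.asarray((idx == i + 1).sum(), dtype='int32')
            for i in range(n_weights)]


idx = [1, -2, 2, 0, 1]
print(mask_counts(idx, 2, encoding=True))
print(mask_counts(idx, 2, encoding=False))
```

With this index list, encoding yields counts 2 and 1, while decoding yields 2 and 2, because the entry -2 then also selects the second matrix.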
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
588 |
717
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
# TODO The classes FillMissing and MaskSelect below should probably be moved
# to another (more appropriate) file.
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
class FillMissing(Op):
    """
    Given an input, output two elements:
    - a copy of the input where missing values (NaN) are replaced by some
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
      other value (zero by default)
720
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
    - a mask of the same size and type as input, where each element is zero
      iff the corresponding input is missing
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
    The 'fill_with' parameter may either be:
    - a scalar: all missing values are replaced with this value
    - a Numpy array: a missing value is replaced by the value in this array
      at the same position (ignoring the first k dimensions if 'fill_with'
      has k fewer dimensions than the input)
    Currently, the gradient is computed as if the input value were really what
    it was replaced with. It may be safer to replace the gradient w.r.t.
    missing values with either zeros or missing values (?).
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
    """

745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
    def __init__(self, fill_with=0):
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
        super(FillMissing, self).__init__()
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
        self.fill_with = fill_with
        self.fill_with_is_array = isinstance(self.fill_with, numpy.ndarray)
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|

    def __eq__(self, other):
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
        if type(self) != type(other):
            return False
        if self.fill_with_is_array != other.fill_with_is_array:
            return False
        if self.fill_with_is_array:
            # Compare shape and content, returning a single bool.
            return numpy.array_equal(self.fill_with, other.fill_with)
        return self.fill_with == other.fill_with
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
619 |
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
    def __hash__(self):
        if self.fill_with_is_array:
            # numpy arrays are unhashable, so hash their shape and content.
            fill_hash = hash((self.fill_with.shape, tuple(self.fill_with.flat)))
        else:
            fill_hash = hash(self.fill_with)
        return hash(type(self))^hash(self.fill_with_is_array)^fill_hash
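Because numpy arrays are unhashable and compare element-wise, a parameter that may be either a scalar or an array needs explicit equality and hashing that stay consistent with each other. A standalone sketch of that pattern; the class `FillParam` is hypothetical and not part of Theano:

```python
import numpy


class FillParam(object):
    """Hypothetical holder for a scalar-or-array fill value."""

    def __init__(self, fill_with=0):
        self.fill_with = fill_with
        self.is_array = isinstance(fill_with, numpy.ndarray)

    def __eq__(self, other):
        if type(self) != type(other) or self.is_array != other.is_array:
            return False
        if self.is_array:
            # array_equal checks shape and content and returns a single bool.
            return numpy.array_equal(self.fill_with, other.fill_with)
        return self.fill_with == other.fill_with

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        if self.is_array:
            # ndarray.__hash__ is None, so hash the shape and values instead.
            key = (self.fill_with.shape, tuple(self.fill_with.flat))
        else:
            key = self.fill_with
        return hash((type(self), self.is_array, key))


a = FillParam(numpy.array([1.0, 2.0]))
b = FillParam(numpy.array([1.0, 2.0]))
print(a == b, hash(a) == hash(b), a == FillParam(0))
```

Equal arrays give equal objects with equal hashes, while an array-valued and a scalar-valued parameter never compare equal.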
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|

    def make_node(self, input):
720
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
        return Apply(self, [input], [input.type(), input.type()])
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
629 |
724
d42b4bcbb582
Replaced debug special code for missing values (-123456) by truly missing (NaN)
Olivier Delalleau <delallea@iro>
parents:
720
diff
changeset
|
    def perform(self, node, inputs, output_storage):
        # Unpacked here because tuple parameters are not valid Python 3 syntax.
        input, = inputs
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
        out = output_storage[0]
        out[0] = input.copy()
        out = out[0]
        mask = output_storage[1]
770
742972b6906a
scall optimisation to FillMissing.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
768
diff
changeset
|
        if (mask[0] is None or mask[0].shape != input.shape
                or mask[0].dtype != input.dtype):
            mask[0] = numpy.ones(input.shape, dtype=input.dtype)
        else:
            # Reused buffer: it may still hold zeros from a previous call.
            mask[0].fill(1)
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
        mask = mask[0]
745
fc85ce33b518
FillMissing can now impute missing values by an array instead of a single constant
Olivier Delalleau <delallea@iro>
parents:
724
diff
changeset
|
        if self.fill_with_is_array:
771
72730f38d1fb
opt of the FillMissing op. Now 80-90% faster python implementation.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
770
diff
changeset
|
            # numpy.ndenumerate is slower than explicit loops, so we
            # special-case the most frequently used numbers of dimensions.
            if out.ndim==1:
                assert self.fill_with.ndim==1
                for i in range(out.shape[0]):
                    if numpy.isnan(out[i]):
                        out[i] = self.fill_with[i]
                        mask[i] = 0
            elif out.ndim==2 and self.fill_with.ndim==1:
                for i in range(out.shape[0]):
                    for j in range(out.shape[1]):
                        if numpy.isnan(out[i,j]):
                            out[i,j] = self.fill_with[j]
                            mask[i,j] = 0
            else:
                ignore_k = out.ndim - self.fill_with.ndim
                assert ignore_k >= 0
                for (idx, v) in numpy.ndenumerate(out):
                    if numpy.isnan(v):
                        out[idx] = self.fill_with[idx[ignore_k:]]
                        mask[idx] = 0
770
742972b6906a
scall optimisation to FillMissing.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
768
diff
changeset
|
        else:
771
72730f38d1fb
opt of the FillMissing op. Now 80-90% faster python implementation.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
770
diff
changeset
|
            # numpy.ndenumerate is slower than explicit loops, so we
            # special-case the most frequently used numbers of dimensions.
            if out.ndim==1:
                for i in range(out.shape[0]):
                    if numpy.isnan(out[i]):
                        out[i] = self.fill_with
                        mask[i] = 0
            elif out.ndim==2:
                for i in range(out.shape[0]):
                    for j in range(out.shape[1]):
                        if numpy.isnan(out[i,j]):
                            out[i,j] = self.fill_with
                            mask[i,j] = 0
            else:
                for (idx, v) in numpy.ndenumerate(out):
                    if numpy.isnan(out[idx]):
                        out[idx] = self.fill_with
                        mask[idx] = 0
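The element-wise loops in perform can also be expressed with NumPy broadcasting: numpy.isnan builds the missing-value mask, and numpy.broadcast_to aligns a lower-dimensional fill_with with the trailing dimensions of the input, which corresponds to the "ignore the first k dimensions" rule in the docstring. This is a sketch of a vectorized alternative, not the code the Op uses; the function name `fill_missing` is hypothetical:

```python
import numpy


def fill_missing(x, fill_with=0):
    """Return (filled copy of x, mask) where mask is 0 exactly at NaNs of x."""
    x = numpy.asarray(x)
    missing = numpy.isnan(x)
    # Broadcasting aligns fill_with with the trailing dimensions of x,
    # i.e. its first k dimensions are ignored when it has k fewer.
    fill = numpy.broadcast_to(numpy.asarray(fill_with, dtype=x.dtype), x.shape)
    out = numpy.where(missing, fill, x)
    mask = numpy.where(missing, 0, 1).astype(x.dtype)
    return out, mask


x = numpy.array([[1.0, numpy.nan, 3.0],
                 [numpy.nan, 5.0, numpy.nan]])
out, mask = fill_missing(x, numpy.array([10.0, 20.0, 30.0]))
print(out)
print(mask)
```

Here out is [[1, 20, 3], [10, 5, 30]] and mask is [[1, 0, 1], [0, 1, 0]], while x itself is left unchanged. One difference from the loops: broadcast_to also accepts size-1 dimensions in fill_with, which the indexing in perform would not.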
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
681 |
720
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
    def grad(self, inputs, output_grads):
        # The mask output gets no gradient; the output gradient is passed
        # through unchanged as the gradient w.r.t. the input.
        out_grad, mask_grad = output_grads
        return [out_grad]
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
684 |
    #def c():
    def c_no_compile_args(self):
        # -ffast-math and -ffinite-math-only MUST NOT be enabled: they break isnan().
        # The same holds for -funsafe-math-optimizations on gcc 4.1 (on gcc 4.3 it does not break isnan()).
        return ["-ffast-math", "-ffinite-math-only",
                # Needed for gcc 4.1 only, not for gcc 4.3.
                # TODO: find a way to return this list depending on the compiler in use?
                "-funsafe-math-optimizations"
                ]
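`c_no_compile_args` lists compiler flags that must not be used when compiling this Op's C code. The TODO above asks for a list that depends on the compiler. The following is a minimal sketch of that idea, assuming the gcc version is already known as a `(major, minor)` tuple (how to obtain it is not shown). Both function names are hypothetical and are not Theano APIs. The comments above only mention gcc 4.1 and 4.3, so treating every version older than 4.3 like 4.1 is an assumption.

```python
def fill_missing_no_compile_args(gcc_version):
    """Hypothetical version-aware variant of FillMissing.c_no_compile_args.

    gcc_version is assumed to be a (major, minor) tuple such as (4, 1).
    """
    # These flags let gcc assume values are never NaN, so isnan() checks
    # may be optimized away; they are always excluded.
    flags = ["-ffast-math", "-ffinite-math-only"]
    # Per the source comments, this flag also breaks isnan() on gcc 4.1
    # but not on gcc 4.3; the (4, 3) cutoff is an assumption.
    if tuple(gcc_version) < (4, 3):
        flags.append("-funsafe-math-optimizations")
    return flags


def remove_flags(compile_args, banned):
    """Return compile_args without any flag listed in banned (hypothetical helper)."""
    banned = set(banned)
    return [arg for arg in compile_args if arg not in banned]
```

Using a set makes each membership test constant-time. The remaining flags keep their original order.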

    def c_headers(self):
        return ['"Python.h"', '"numpy/noprefix.h"', '<math.h>', '<sstream>']

    def c_support_code(self):
        return """
        using namespace std;
        """

    def c_code(self, node, name, (input,), (value, mask), sub):
        if self.fill_with is None:
            print "OPTIMISATION WARNING: FillMissing does not support fill_with=None in C. Reverting to the Python version.", self.fill_with_is_array, node.inputs[0].ndim
            return super(FillMissing, self).c_code(node, name, (input,), (value, mask), sub)
        if (self.fill_with_is_array and node.inputs[0].ndim not in [1, 2]) or (node.inputs[0].ndim not in [1, 2, 3]):
            print "OPTIMISATION WARNING: FillMissing does not implement this case in C. Reverting to the Python version.", self.fill_with_is_array, node.inputs[0].ndim
            return super(FillMissing, self).c_code(node, name, (input,), (value, mask), sub)

        d = locals()
        d.update(sub)
        d["self.fill_with_is_array"] = 1 if self.fill_with_is_array else 0
        d["self.fill_with"] = self.fill_with
        if self.fill_with_is_array:
            d["self.fill_with_length"] = str(self.fill_with.size)
            # Serialize fill_with as the body of a C array initializer, e.g. "1.5,0.0,2.0".
            # repr(float(...)) keeps full double precision (str() may round to 12 digits on Python 2).
            d["self.fill_with_data"] = ",".join(repr(float(i)) for i in self.fill_with.flatten())
            d["self.fill_with.ndim"] = str(self.fill_with.ndim)
        else:
            d["self.fill_with_length"] = str(1)
            d["self.fill_with_data"] = repr(float(self.fill_with))
            d["self.fill_with.ndim"] = 0
        if node.inputs[0].type.dtype == "float32": d["type"] = "float"
        elif node.inputs[0].type.dtype == "float64": d["type"] = "double"
        else: raise Exception("Type %s not implemented" % node.inputs[0].type.dtype)

        return """
        //This comment was added to force recompilation, as the compiler options were changed.
        int typenum;
        PyArrayObject* input = %(input)s, *value = %(value)s, *mask = %(mask)s;
        %(type)s fill_with[%(self.fill_with_length)s] = {%(self.fill_with_data)s};

        if(!PyArray_Check(input)){
            PyErr_SetString(PyExc_ValueError, "input must be an ndarray");
            %(fail)s;
        }

        typenum = PyArray_ObjectType((PyObject*)input, 0);
        if(!value || !PyArray_SAMESHAPE(value,input) || !PyArray_ISCONTIGUOUS(value)){
            //The output is fully overwritten below, so a non-contiguous output buffer is
            //replaced by a fresh contiguous one (writing into a contiguous copy of it
            //would leave the actual output variable unchanged).
            Py_XDECREF(value);
            value = (PyArrayObject*) PyArray_ZEROS(input->nd, input->dimensions, typenum, 0);
            %(value)s = value;
            if(!value){
                %(fail)s;
            }
        }

        if(!mask || !PyArray_SAMESHAPE(mask,input) || !PyArray_ISCONTIGUOUS(mask)){
            //Same reasoning as for value.
            Py_XDECREF(mask);
            mask = (PyArrayObject*) PyArray_ZEROS(input->nd, input->dimensions, typenum, 0);
            %(mask)s = mask;
            if(!mask){
                %(fail)s;
            }
        }

        if(!PyArray_ISCONTIGUOUS(input)){
            cout<<"OPTIMISATION WARNING: in FillMissing, the input is not contiguous in memory, so we make a contiguous copy. This could be optimized by using the data directly."<<endl;
            input = PyArray_GETCONTIGUOUS((PyArrayObject*)input);
            if(!input || !PyArray_ISCONTIGUOUS(input)){
                PyErr_SetString(PyExc_ValueError, "input is not contiguous in memory");
                %(fail)s;
            }
        }

        if(input->nd!=value->nd || input->nd!=mask->nd){
            PyErr_Format(PyExc_ValueError,
                         "FillMissing: the input has %%d dims, the mask has %%d dims and the value has %%d dims. They should all be equal.\\n",
                         input->nd, mask->nd, value->nd);
            %(fail)s;
        }
781
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
785 #if %(self.fill_with_is_array)s |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
786 if(input->nd==1){ |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
787 %(type)s* value_ = (%(type)s*)(value->data); |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
788 %(type)s* mask_ = (%(type)s*)(mask->data); |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
789 %(type)s* input_ = (%(type)s*)(input->data); |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
790 for(int i=0;i<input->dimensions[0];i++){ |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
791 if(isnan(input_[i])){ |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
792 value_[i]=fill_with[i]; |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
793 mask_[i]=0; |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
794 }else{ |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
795 value_[i]=input_[i]; |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
796 mask_[i]=1; |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
797 |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
798 } |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
799 } |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
800 }else if(input->nd==2 && %(self.fill_with.ndim)s==1){ |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
801 for(int i=0; i<input->dimensions[0];i++){ |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
802 %(type)s* value_ = (%(type)s*) PyArray_GETPTR2(value,i,0); |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
803 %(type)s* mask_ = (%(type)s*) PyArray_GETPTR2(mask,i,0); |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
804 %(type)s* input_ = (%(type)s*) PyArray_GETPTR2(input,i,0); |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
805 for(int j=0; j<input->dimensions[1];j++){ |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
806 if(isnan(input_[j])){ |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
807 value_[j]=fill_with[j]; |
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
808 mask_[j]=0; |
809 }else{ |
810 value_[j]=input_[j]; |
811 mask_[j]=1; |
812 } |
813 } |
814 } |
815 }else{//not implemented! |
816 //SHOULD not happen as c_code should revert to the python version in that case |
852
d15683416ebf
some fix to the c code of FillMissing. It was not compiling.
Frederic Bastien <nouiz@nouiz.org>
parents:
817
diff
changeset
|
817 std::stringstream temp; |
781
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
818 temp << "In FillMissing, we try to fill with an array and the input ndim is implemented only for 1 and 2. This case is not implemented."<<endl; |
819 temp << " ndim="<<input->nd<<endl;; |
820 std::string param = temp.str(); |
821 PyErr_SetString(PyExc_ValueError, param.c_str()); |
822 %(fail)s |
823 } |
824 #else |
825 //we fill with a scalar |
826 if(input->nd==1){ |
827 %(type)s* value_ = (%(type)s*)(value->data); |
828 %(type)s* mask_ = (%(type)s*)(mask->data); |
829 %(type)s* input_ = (%(type)s*)(input->data); |
830 for(int i=0;i<input->dimensions[0];i++){ |
831 if(isnan(input_[i])){ |
832 value_[i]=%(self.fill_with)s; |
833 mask_[i]=0; |
834 }else{ |
835 value_[i]=input_[i]; |
836 mask_[i]=1; |
837 |
838 } |
839 } |
840 }else if(input->nd==2){ |
841 for(int i=0;i<input->dimensions[0];i++){ |
842 %(type)s* value_ = (%(type)s*) PyArray_GETPTR2(value,i,0); |
843 %(type)s* mask_ = (%(type)s*) PyArray_GETPTR2(mask,i,0); |
844 %(type)s* input_ = (%(type)s*) PyArray_GETPTR2(input,i,0); |
845 for(int j=0;j<input->dimensions[1];j++){ |
846 if(isnan(input_[j])){ |
847 value_[j]=%(self.fill_with)s; |
848 mask_[j]=0; |
849 }else{ |
850 value_[j]=input_[j]; |
851 mask_[j]=1; |
852 } |
853 } |
854 } |
855 }else if(input->nd==3){ |
856 for(int i=0;i<input->dimensions[0];i++){ |
857 for(int j=0;j<input->dimensions[1];j++){ |
858 %(type)s* value_ = (%(type)s*) PyArray_GETPTR3(value,i,j,0); |
859 %(type)s* mask_ = (%(type)s*) PyArray_GETPTR3(mask,i,j,0); |
860 %(type)s* input_ = (%(type)s*) PyArray_GETPTR3(input,i,j,0); |
861 for(int k=0;k<input->dimensions[2];k++){ |
862 if(isnan(input_[k])){ |
863 value_[k]=%(self.fill_with)s; |
864 mask_[k]=0; |
865 }else{ |
866 value_[k]=input_[k]; |
867 mask_[k]=1; |
868 } |
869 } |
870 } |
871 } |
872 }else{//not implemented! |
873 //SHOULD not happen as c_code should revert to the python version in that case |
852
d15683416ebf
some fix to the c code of FillMissing. It was not compiling.
Frederic Bastien <nouiz@nouiz.org>
parents:
817
diff
changeset
|
874 std::stringstream temp; |
781
b6670cb57101
implemented FillMissing.c_code. It use the new c_no_compile_args to remove -ffast-math and -ffinite-math-only as they broke isnan.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
771
diff
changeset
|
875 temp << "In FillMissing, we try to fill with a constant and the input ndim is implemented only for 1, 2 and 3."; |
876 temp << " ndim="<<input->nd<<endl;; |
877 std::string param = temp.str(); |
878 PyErr_SetString(PyExc_ValueError, param.c_str()); |
879 %(fail)s |
880 } |
881 #endif |
882 |
883 """%d |
716
d7c6dadb4aa9
New class FillMissing (should probably be moved to another file later)
Olivier Delalleau <delallea@iro>
parents:
711
diff
changeset
|
884 fill_missing_with_zeros = FillMissing(0) |
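For reference, the C code above applies the same per-element rule for every supported ndim. Where the input is NaN, it writes the fill value and sets the mask to 0. The fill value is the scalar constant, or in the 2-d array case `fill_with[j]` for column `j`. Otherwise it copies the input and sets the mask to 1. The mask has the same dtype as the input.

The sketch below restates that rule with NumPy broadcasting. `fill_missing_reference` is a hypothetical helper for checking outputs, not part of pylearn or Theano. It assumes `fill_with` broadcasts against the input the way the array branch indexes it.

```python
import numpy as np

def fill_missing_reference(x, fill_with):
    """Hypothetical NumPy restatement of FillMissing's per-element rule.

    Returns (value, mask): NaN entries of ``x`` are replaced by ``fill_with``
    (a scalar, or an array broadcastable against ``x``, e.g. one value per
    column of a matrix) and flagged with 0 in ``mask``; every other entry is
    copied unchanged and flagged with 1. Both outputs keep the dtype of ``x``.
    """
    x = np.asarray(x)
    missing = np.isnan(x)
    value = np.where(missing, fill_with, x).astype(x.dtype)
    mask = (~missing).astype(x.dtype)
    return value, mask

x = np.array([[1.0, np.nan], [np.nan, 4.0]])
value, mask = fill_missing_reference(x, 0)
# value.tolist() == [[1.0, 0.0], [0.0, 4.0]]
# mask.tolist()  == [[1.0, 0.0], [0.0, 1.0]]
value, mask = fill_missing_reference(x, np.array([10.0, 20.0]))
# value.tolist() == [[1.0, 20.0], [10.0, 4.0]]  (per-column fill values)
```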
885 |
720
0594cba02fa8
Some fixes to FillMissing, and added MaskGradient (another Op to move elsewhere later)
Olivier Delalleau <delallea@iro>
parents:
719
diff
changeset
|
886 class MaskGradient(Op): |
887 """ |
888 Takes as input a tensor and a mask. Outputs the same tensor, but setting |
889 to zero the gradient for all elements where the mask's value is zero. |
890 """ |
891 |
892 def __eq__(self, other): |
893 return type(self) == type(other) |
894 |
895 def __hash__(self): |
896 return hash(type(self)) |
897 |
898 def make_node(self, input, mask): |
899 return Apply(self, [input, mask], [input.type()]) |
900 |
901 def perform(self, node, (input, mask), (output, )): |
902 output[0] = input.copy() |
903 |
904 def grad(self, (input, mask), (out_grad, )): |
905 return [out_grad * T.neq(mask, 0), None] |
906 |
907 mask_gradient = MaskGradient() |
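MaskGradient leaves the forward value unchanged, since `perform` just copies the input. It only alters the backward pass: the incoming gradient is multiplied by `T.neq(mask, 0)`, and the mask itself receives no gradient (`None`).

The hedged NumPy sketch below restates the two halves outside Theano. Both function names are hypothetical and for illustration only.

```python
import numpy as np

def mask_gradient_forward(x, mask):
    # Forward pass: a plain copy of the input; the mask plays no role here.
    return np.array(x, copy=True)

def mask_gradient_backward(out_grad, mask):
    # Backward pass: zero the incoming gradient wherever mask == 0, mirroring
    # ``out_grad * T.neq(mask, 0)``; the mask input gets no gradient (None).
    return out_grad * (np.asarray(mask) != 0), None

x = np.array([1.0, 2.0, 3.0])
mask = np.array([1.0, 0.0, 1.0])
y = mask_gradient_forward(x, mask)            # same values as x, new array
g, g_mask = mask_gradient_backward(np.array([0.5, 0.5, 0.5]), mask)
# g.tolist() == [0.5, 0.0, 0.5], g_mask is None
```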
908 |
717
bf29e201515f
New class MaskSelect (should probably be moved to another file)
Olivier Delalleau <delallea@iro>
parents:
716
diff
changeset
|
909 class MaskSelect(Op): |
910 """ |
911 Given an input x and a mask m (both vectors), outputs a vector that |
912 contains all elements x[i] such that bool(m[i]) is True. |
913 """ |
914 |
915 def __eq__(self, other): |
916 return type(self) == type(other) |
917 |
918 def __hash__(self): |
919 return hash(type(self)) |
920 |
921 def make_node(self, input, mask): |
922 return Apply(self, [input, mask], [input.type()]) |
923 |
924 def perform(self, node, (input, mask), (output, )): |
925 select = [] |
926 for (i, m) in enumerate(mask): |
927 if bool(m): |
928 select.append(i) |
929 output[0] = numpy.zeros(len(select), dtype = input.dtype) |
930 out = output[0] |
931 for (i, j) in enumerate(select): |
932 out[i] = input[j] |
933 |
934 mask_select = MaskSelect() |
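`MaskSelect.perform` works in two steps. It first collects every index `i` for which `bool(mask[i])` is true, then copies those `input[i]` into a fresh array with the input's dtype.

For 1-d inputs this appears equivalent to NumPy boolean indexing, which also returns a new array rather than a view. The sketch below shows that shortcut under a hypothetical helper name; it is not part of pylearn.

```python
import numpy as np

def mask_select_reference(x, mask):
    # Keep x[i] wherever bool(mask[i]) is True, as MaskSelect.perform does,
    # but via NumPy boolean indexing instead of an explicit Python loop.
    # Casting the mask with astype(bool) matches bool(m): any non-zero
    # value (including NaN) counts as True.
    x = np.asarray(x)
    keep = np.asarray(mask).astype(bool)
    return x[keep]  # boolean indexing returns a copy with x's dtype

selected = mask_select_reference(np.array([1.0, 2.0, 3.0, 4.0]),
                                 np.array([0, 1, 0, 2]))
# selected.tolist() == [2.0, 4.0]
```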