import numpy
import theano
from theano import tensor as T
from theano.gof import Op
from theano.gof import Apply
from theano import scalar as scal

# These Ops allow us to deal efficiently with static groups of possibly missing inputs in the dense DAA framework
# (for example with multimodal data where an entire modality is sometimes missing).
# The inputs are represented by an index list and a theano.generic variable holding a list of matrices
# (numpy arrays): each element corresponds to an available modality, and the index list indicates the weights
# associated with it.
# Example of index list: [1, 0, -3]
#    *the 1 says that the first element of the input list will refer to the first element of the weights_list
#        (auxiliary target as input)
#                                if indexlist[i]>0 it refers to weightslist[indexlist[i]-1]
#    *the 0 means that the second element of the input list will be neither encoded nor decoded (it is replaced by zeros)
#        this is not efficient, so in this case it is better to give: [1,-3] and [inputslist[0],inputslist[2]]
#        but it allows us to deal with empty lists: give indexlist = numpy.asarray([0], dtype='int32')
#        and inputlist = [numpy.zeros((batchsize,1))]
#    *when an index is negative it means that the input will not be used for encoding but we will still reconstruct it
#        (auxiliary target as output)
#                                if indexlist[i]<0 it refers to weightslist[-indexlist[i]-1]
# An entire batch should have the same available inputs configuration.
# Dense DAA example:----------------------------------------------------------------------------
#from theano.tensor.nnet import sigmoid
#nb_modality = 4
#wenc = [T.dmatrix('wenc%s'%i) for i in range(nb_modality)]
#wdec = [T.dmatrix('wdec%s'%i) for i in range(nb_modality)]
#benc = T.dvector('benc')
#bdec = [T.dvector('bdec%s'%i) for i in range(nb_modality)]
#vectin = T.ivector('vectin')
#inputpart = theano.generic('inputpart')
#noise_bit = T.dscalar('noise_bit')
#noise_group = T.dscalar('noise_group')
#[vectin2,inputpart2] = scannoise(vectin,inputpart,noise_bit,noise_group)
#hid = scandotenc(vectin2, inputpart2, wenc)
#acthid = sigmoid(hid + benc)
#dec = sigmoid(scanbiasdec(vectin2,inputpart2,bdec) + scandotdec(vectin2, inputpart2,acthid,wdec))
#cost = T.sum(T.sum(T.sqr( scaninput(vectin,inputpart) - dec ),1),0)

# Checking inputs in make_node methods----------------------
def Checkidx_list(idx_list):
    """Return idx_list as a tensor variable, requiring it to be a vector.

    Raises TypeError('not vector', variable) when the converted variable
    does not have exactly one dimension.
    """
    variable = T.as_tensor_variable(idx_list)
    if variable.type.ndim == 1:
        return variable
    raise TypeError('not vector', variable)

def Checkhidd(hidd):
    """Return hidd as a tensor variable, requiring it to be a vector or matrix.

    Raises TypeError('not matrix or vector', variable) for any other rank.
    """
    variable = T.as_tensor_variable(hidd)
    if variable.type.ndim in (1, 2):
        return variable
    raise TypeError('not matrix or vector', variable)

def Checkweights_list(weights_list):
    """Convert every weight to a tensor variable and check its rank.

    :param weights_list: iterable of weight matrices/vectors (anything
        accepted by T.as_tensor_variable).
    :returns: a *list* of tensor variables (callers concatenate it with other
        lists, so an iterator would not do).
    :raises TypeError: ('not matrix or vector', variable) if a weight is not
        1-d or 2-d.
    """
    # A list comprehension (rather than map) guarantees a real list on both
    # Python 2 and Python 3.
    weights_list = [T.as_tensor_variable(w) for w in weights_list]
    for weights in weights_list:
        if weights.type.ndim not in (1, 2):
            raise TypeError('not matrix or vector', weights)
    return weights_list

def Checkbias_list(bias_list):
    """Convert every bias to a tensor variable and check that it is a vector.

    :param bias_list: iterable of bias vectors (anything accepted by
        T.as_tensor_variable).
    :returns: a *list* of tensor variables (callers concatenate it with other
        lists, so an iterator would not do).
    :raises TypeError: ('not vector', variable) if a bias is not 1-d.
    """
    # A list comprehension (rather than map) guarantees a real list on both
    # Python 2 and Python 3.
    bias_list = [T.as_tensor_variable(b) for b in bias_list]
    for bias in bias_list:
        if bias.type.ndim != 1:
            raise TypeError('not vector', bias)
    return bias_list

# Encoding scan dot product------------------------------------
class ScanDotEnc(Op):
    """This Op takes an index list (as tensor.ivector), a list of matrices representing
    the available inputs (as theano.generic), and all the encoding weights tensor.dmatrix of the model. It will select the
    weights corresponding to the inputs (according to index list) and compute only the necessary dot products.

    For every strictly positive entry idx_list[i], numpy.dot(inputs_list[i],
    weights[idx_list[i]-1]) is added to the output; zero and negative entries
    are skipped. The output has shape (rows of inputs_list[0], columns of the
    first weight matrix) and is all zeros when no entry is positive.
    """

    def __eq__(self, other):
        # The Op has no parameters, so all instances are interchangeable.
        return type(self) == type(other)

    def __hash__(self):
        return hash(ScanDotEnc)^58994

    def __str__(self):
        return "ScanDotEnc"

    def make_node(self, idx_list, inputs_list, weights_list):
        idx_list = Checkidx_list(idx_list)
        # list() keeps the concatenation below valid whatever iterable the check returns.
        weights_list = list(Checkweights_list(weights_list))
        return Apply(self, [idx_list] + [inputs_list] + weights_list, [T.dmatrix()])

    def perform(self, node, args, outputs):
        hid, = outputs
        idx_list = args[0]
        inputs_list = args[1]
        weights = args[2:]
        batchsize = (inputs_list[0].shape)[0]
        n_hid = (args[2].shape)[1]
        if len(idx_list) != len(inputs_list) :
            raise NotImplementedError('size of index different of inputs list size',idx_list)
        if max(idx_list) >= (len(args)-2)+1 :
            raise NotImplementedError('index superior to weight list length',idx_list)
        for a in inputs_list:
            if (a.shape)[0] != batchsize:
                raise NotImplementedError('different batchsize in the inputs list',a.shape)
        for a in weights:
            if (a.shape)[1] != n_hid:
                raise NotImplementedError('different length of hidden in the weights list',a.shape)
        # float64 accumulator, matching the declared dmatrix output; numpy.dot
        # dispatches to BLAS for float64 matrices.
        acc = numpy.zeros((batchsize, n_hid))
        for x, idx in zip(inputs_list, idx_list):
            if idx > 0:
                acc += numpy.dot(x, weights[int(idx) - 1])
        hid[0] = acc

    def grad(self, args, gz):
        gradi = ScanDotEncGrad()(args, gz)
        # Calling an Op with a single output yields a Variable, not a list.
        if not isinstance(gradi, (list, tuple)):
            gradi = [gradi]
        # No gradient w.r.t. the index list nor the (generic) inputs list.
        return [None, None] + list(gradi)


class ScanDotEncGrad(Op):
    """This Op computes the gradient wrt the weights for ScanDotEnc.

    Its inputs are the inputs of ScanDotEnc followed by g_out, the gradient of
    the ScanDotEnc output. Output k is the sum of numpy.dot(inputs_list[i].T,
    g_out) over every i with idx_list[i] == k+1; weights that are never
    referenced get zeros of their own shape.
    """

    def __eq__(self, other):
        # The Op has no parameters, so all instances are interchangeable.
        return type(self) == type(other)

    def __hash__(self):
        return hash(ScanDotEncGrad)^15684

    def __str__(self):
        return "ScanDotEncGrad"

    def make_node(self, args, g_out):
        # Validation only: the raw variables are passed on unchanged.
        Checkidx_list(args[0])
        list(Checkweights_list(args[2:]))
        return Apply(self, list(args) + list(g_out), [T.dmatrix() for i in range(2, len(args))])

    def perform(self, node, args, z):
        idx_list = args[0]
        inputs_list = args[1]
        weights = args[2:-1]
        g_out = args[-1]
        batchsize = (inputs_list[0].shape)[0]
        n_hid = (args[2].shape)[1]
        if len(idx_list) != len(inputs_list) :
            raise NotImplementedError('size of index different of inputs list size',idx_list)
        if max(idx_list) >= (len(args)-3)+1 :
            raise NotImplementedError('index superior to weight list length',idx_list)
        for a in inputs_list:
            if (a.shape)[0] != batchsize:
                raise NotImplementedError('different batchsize in the inputs list',a.shape)
        for a in weights:
            if (a.shape)[1] != n_hid:
                raise NotImplementedError('different length of hidden in the weights list',a.shape)
        # One float64 accumulator per weight matrix; unused weights stay zero.
        grads = [numpy.zeros(w.shape) for w in weights]
        for x, idx in zip(inputs_list, idx_list):
            if idx > 0:
                # d(x . W)/dW contracted with g_out is x^T . g_out
                grads[int(idx) - 1] += numpy.dot(x.T, g_out)
        for k, g in enumerate(grads):
            z[k][0] = g

# Decoding scan dot product------------------------------------
class ScanDotDec(Op):
    """This Op takes an index list (as tensor.ivector), a list of matrices representing
    the available inputs (as theano.generic), the hidden layer of the DAA (theano.dmatrix)
    and all the decoding weights tensor.dmatrix of the model. It will select the
    weights corresponding to the available inputs (according to index list) and compute
    only the necessary dot products. The outputs will be concatenated and will represent
    the reconstruction of the different modality in the same order than the index list.

    Negative indices are decoded like positive ones (abs is taken). For a zero
    index the corresponding block is zeros of shape
    (batchsize, inputs_list[i].shape[1]).
    """

    def __eq__(self, other):
        # The Op has no parameters, so all instances are interchangeable.
        return type(self) == type(other)

    def __hash__(self):
        return hash(ScanDotDec)^73568

    def __str__(self):
        return "ScanDotDec"

    def make_node(self, idx_list, input_list, hidd, weights_list):
        idx_list = Checkidx_list(idx_list)
        hidd = Checkhidd(hidd)
        weights_list = list(Checkweights_list(weights_list))
        return Apply(self, [idx_list] + [input_list] + [hidd] + weights_list, [T.dmatrix()])

    def perform(self, node, args, outputs):
        z, = outputs
        idx_list = numpy.abs(args[0])
        inputs_list = args[1]
        hid = args[2]
        weights = args[3:]
        batchsize = (hid.shape)[0]
        n_hid = hid.shape[1]
        if max(idx_list) >= len(args)-3+1 :
            raise NotImplementedError('index superior to weight list length',idx_list)
        if len(idx_list) != len(inputs_list) :
            raise NotImplementedError('size of index different of inputs list size',idx_list)
        for a in weights:
            if (a.shape)[0] != n_hid:
                raise NotImplementedError('different length of hidden in the weights list',a.shape)
        blocks = []
        for x, idx in zip(inputs_list, idx_list):
            if idx > 0:
                blocks.append(numpy.dot(hid, weights[int(idx) - 1]))
            else:
                # Missing modality: zeros as wide as *this* input.
                blocks.append(numpy.zeros((batchsize, x.shape[1])))
        z[0] = numpy.concatenate(blocks, 1)

    def grad(self, args, gz):
        gradi = ScanDotDecGrad()(args, gz)
        # Calling an Op with a single output yields a Variable, not a list.
        if not isinstance(gradi, (list, tuple)):
            gradi = [gradi]
        # No gradient w.r.t. the index list nor the (generic) inputs list.
        return [None, None] + list(gradi)


class ScanDotDecGrad(Op):
    """This Op computes the gradient wrt the hidden layer and the weights for ScanDotDec.

    Its inputs are the inputs of ScanDotDec followed by g_out, the gradient of
    the concatenated ScanDotDec output. Output 0 is the gradient w.r.t. the
    hidden layer; output k+1 is the gradient w.r.t. decoding weight k (zeros
    of the weight's shape when it is never referenced).
    """

    def __eq__(self, other):
        # The Op has no parameters, so all instances are interchangeable.
        return type(self) == type(other)

    def __hash__(self):
        return hash(ScanDotDecGrad)^87445

    def __str__(self):
        return "ScanDotDecGrad"

    def make_node(self, args, g_out):
        # Validation only: the raw variables are passed on unchanged.
        Checkidx_list(args[0])
        Checkhidd(args[2])
        list(Checkweights_list(args[3:]))
        return Apply(self, list(args) + list(g_out), [T.dmatrix() for i in range(2, len(args))])

    def perform(self, node, args, z):
        idx_list = numpy.abs(args[0])
        inputs_list = args[1]
        hid = args[2]
        weights = args[3:-1]
        g_out = args[-1]
        n_hid = hid.shape[1]
        if max(idx_list) >= len(args)-4+1 :
            raise NotImplementedError('index superior to weight list length',idx_list)
        if len(idx_list) != len(inputs_list) :
            raise NotImplementedError('size of index different of inputs list size',idx_list)
        for a in weights:
            if a.shape[0] != n_hid:
                raise NotImplementedError('different length of hidden in the weights list',a.shape)
        # Integer column offsets of each modality's block inside g_out.
        bounds = [0]
        for x, idx in zip(inputs_list, idx_list):
            if idx == 0:
                width = x.shape[1]
            else:
                width = weights[int(idx) - 1].shape[1]
            bounds.append(bounds[-1] + width)
        g_hid = numpy.zeros(hid.shape)
        g_weights = [numpy.zeros(w.shape) for w in weights]
        for i, idx in enumerate(idx_list):
            if idx > 0:
                w_pos = int(idx) - 1
                g_block = g_out[:, bounds[i]:bounds[i + 1]]
                # out_i = hid . W  =>  dW = hid^T . g_i  and  dhid += g_i . W^T
                g_weights[w_pos] += numpy.dot(hid.T, g_block)
                g_hid += numpy.dot(g_block, weights[w_pos].T)
        z[0][0] = g_hid
        for k, g in enumerate(g_weights):
            z[k + 1][0] = g

# DAA input noise------------------------------------
class ScanNoise(Op):
    """This Op takes an index list (as tensor.ivector), a list of matrices representing
    the available inputs (as theano.generic), a probability of individual bit masking and
    a probability of modality masking. It will return the inputs list with randoms zeros entry
    and the index list with some positive values changed to negative values (groups masking).

    Each positive index is kept with probability 1 - noise_level_group and
    negated otherwise. Inputs whose (new) index is positive have each entry
    kept with probability 1 - noise_level_bit; all other inputs are replaced
    by zeros. Random draws come from a numpy RandomState seeded with `seed`.
    """

    def __init__(self, seed = 1):
        self.seed = seed
        self.rng = numpy.random.RandomState(seed)

    def make_node(self, idx_list, inputs_list, noise_level_bit, noise_level_group):
        idx_list = Checkidx_list(idx_list)
        # Accept plain Python numbers as well as symbolic scalars.
        noise_level_bit = T.as_tensor_variable(noise_level_bit)
        noise_level_group = T.as_tensor_variable(noise_level_group)
        return Apply(self, [idx_list] + [inputs_list] + [noise_level_bit] + [noise_level_group],
                     [T.ivector(), theano.generic()])

    def perform(self, node, inputs, outputs):
        idx_list, inputs_list, noise_level_bit, noise_level_group = inputs
        y, z = outputs
        if len(idx_list) != len(inputs_list) :
            raise NotImplementedError('size of index different of inputs list size',idx_list)
        new_idx = []
        for idx in idx_list:
            # Group masking: a dropped modality keeps its (negated) index so it
            # is still reconstructed but no longer encoded.
            if idx > 0 and self.rng.binomial(1, 1 - noise_level_group) == 0:
                new_idx.append(-idx)
            else:
                new_idx.append(idx)
        # int32 to match the declared ivector output.
        y[0] = numpy.asarray(new_idx, dtype='int32')
        noisy = []
        for idx, x in zip(new_idx, inputs_list):
            if idx > 0:
                noisy.append(self.rng.binomial(1, 1 - noise_level_bit, size=x.shape) * x)
            else:
                noisy.append(numpy.zeros(x.shape))
        z[0] = noisy

    def grad(self, args, gz):
        return [None, None, None, None]

    # No __eq__ on purpose: the Op owns a mutable random state, so two
    # instances must never be merged.
    def __hash__(self):
        return hash(ScanNoise)^hash(self.seed)^hash(self.rng)^12254

    def __str__(self):
        return "ScanNoise"


# Total input matrix construction------------------------------------
class ScanInputs(Op):
    """This Op takes an index list (as tensor.ivector) and a list of matrices representing
    the available inputs (as theano.generic). It will construct the appropriate tensor.dmatrix
    to compare to the reconstruction obtained with ScanDotDec.

    Inputs whose index is 0 are replaced by zeros of the same shape; the
    blocks are concatenated along axis 1 in index-list order.
    """

    def __eq__(self, other):
        # The Op has no parameters, so all instances are interchangeable.
        return type(self) == type(other)

    def __hash__(self):
        return hash(ScanInputs)^75902

    def __str__(self):
        return "ScanInputs"

    def make_node(self, idx_list, inputs_list):
        idx_list = Checkidx_list(idx_list)
        return Apply(self, [idx_list] + [inputs_list], [T.dmatrix()])

    def perform(self, node, inputs, outputs):
        idx_list, inputs_list = inputs
        z, = outputs
        if len(idx_list) != len(inputs_list):
            raise NotImplementedError('size of index different of inputs list size',idx_list)
        # Build a new list: the input storage must not be modified in place.
        # zeros_like (not 0*x) so NaN entries of a missing modality become 0.
        blocks = [numpy.zeros_like(x) if idx == 0 else x
                  for idx, x in zip(idx_list, inputs_list)]
        z[0] = numpy.concatenate(blocks, 1)

    def grad(self, args, gz):
        return [None, None]


# Decoding bias vector construction------------------------------------
class ScanBiasDec(Op):
    """This Op takes an index list (as tensor.ivector), a list of matrices representing
    the available inputs (as theano.generic) and the decoding bias tensor.dvector.
    It will construct the appropriate bias tensor.dvector
    to add to the reconstruction obtained with ScanDotDec.

    For index i the block is bias_list[abs(idx_list[i])-1], or zeros of length
    inputs_list[i].shape[1] when the index is 0. Blocks are concatenated into
    one vector.
    """

    def __eq__(self, other):
        # The Op has no parameters, so all instances are interchangeable.
        return type(self) == type(other)

    def __hash__(self):
        return hash(ScanBiasDec)^60056

    def __str__(self):
        return "ScanBiasDec"

    def make_node(self, idx_list, input_list, bias_list):
        idx_list = Checkidx_list(idx_list)
        bias_list = list(Checkbias_list(bias_list))
        return Apply(self, [idx_list] + [input_list] + bias_list, [T.dvector()])

    def perform(self, node, args, outputs):
        z, = outputs
        idx_list = numpy.abs(args[0])
        inputs_list = args[1]
        biases = args[2:]
        if max(idx_list) >= (len(args)-2)+1 :
            raise NotImplementedError('index superior to bias list length',idx_list)
        if len(idx_list) != len(inputs_list) :
            raise NotImplementedError('size of index different of inputs list size',idx_list)
        pieces = []
        for x, idx in zip(inputs_list, idx_list):
            if idx != 0:
                pieces.append(biases[int(idx) - 1])
            else:
                pieces.append(numpy.zeros(x.shape[1]))
        # The pieces are 1-d, so they are joined along axis 0.
        z[0] = numpy.concatenate(pieces, 0)

    def grad(self, args, gz):
        gradi = ScanBiasDecGrad()(args, gz)
        # Calling an Op with a single output yields a Variable, not a list.
        if not isinstance(gradi, (list, tuple)):
            gradi = [gradi]
        # No gradient w.r.t. the index list nor the (generic) inputs list.
        return [None, None] + list(gradi)


class ScanBiasDecGrad(Op):
    """This Op computes the gradient wrt the bias for ScanBiasDec.

    Its inputs are the inputs of ScanBiasDec followed by g_out, the gradient of
    the concatenated bias vector. Output k is the sum of the g_out slices of
    every block that used bias k (zeros of the bias length when unused).
    """

    def __eq__(self, other):
        # The Op has no parameters, so all instances are interchangeable.
        return type(self) == type(other)

    def __hash__(self):
        return hash(ScanBiasDecGrad)^41256

    def __str__(self):
        return "ScanBiasDecGrad"

    def make_node(self, args, g_out):
        # Validation only: the raw variables are passed on unchanged.
        Checkidx_list(args[0])
        list(Checkbias_list(args[2:]))
        return Apply(self, list(args) + list(g_out), [T.dvector() for i in range(len(args)-2)])

    def perform(self, node, args, z):
        idx_list = numpy.abs(args[0])
        inputs_list = args[1]
        biases = args[2:-1]
        g_out = args[-1]
        if max(idx_list) >= (len(args)-3)+1 :
            raise NotImplementedError('index superior to bias list length',idx_list)
        if len(idx_list) != len(inputs_list) :
            raise NotImplementedError('size of index different of inputs list size',idx_list)
        # Integer offsets of each block inside g_out.
        bounds = [0]
        for x, idx in zip(inputs_list, idx_list):
            if idx == 0:
                width = x.shape[1]
            else:
                width = biases[int(idx) - 1].size
            bounds.append(bounds[-1] + width)
        # Fresh accumulators so g_out (an input) is never modified through a view.
        g_biases = [numpy.zeros(b.size) for b in biases]
        for i, idx in enumerate(idx_list):
            if idx > 0:
                g_biases[int(idx) - 1] += g_out[bounds[i]:bounds[i + 1]]
        for k, g in enumerate(g_biases):
            z[k][0] = g

# Mask construction------------------------------------
class ScanMask(Op):
    """This Op takes an index list (as tensor.ivector) and a list of weigths.
    It will construct a list of T.iscalar representing the Mask
    to do the correct regularisation on the weigths.

    Output i counts the entries of the index list equal to i+1. With
    encbool=True the raw index list is used and the weights' dimension 1 is
    checked for consistency; otherwise abs(index list) is used and dimension 0
    is checked.
    """

    def __init__(self, encbool=True):
        self.encbool = encbool

    def __eq__(self, other):
        return type(self) == type(other) and self.encbool == other.encbool

    def __hash__(self):
        return hash(ScanMask)^hash(self.encbool)^11447

    def __str__(self):
        if self.encbool:
            string = "Enc"
        else:
            string = "Dec"
        return "ScanMask" + string

    def make_node(self, idx_list, weights_list):
        idx_list = Checkidx_list(idx_list)
        weights_list = list(Checkweights_list(weights_list))
        return Apply(self, [idx_list] + weights_list, [T.iscalar() for i in range(len(weights_list))])

    def perform(self, node, args, z):
        if self.encbool:
            idx_list = numpy.asarray(args[0])
            dim = 1
        else:
            # Decoding also uses the negative (auxiliary-target) indices.
            idx_list = numpy.abs(args[0])
            dim = 0
        weights = args[1:]
        n_hid = weights[0].shape[dim]

        if max(idx_list) >= (len(args)-1)+1 :
            raise NotImplementedError('index superior to weights list length',idx_list)
        for a in weights:
            if a.shape[dim] != n_hid:
                raise NotImplementedError('different length of hidden in the encoding weights list',a.shape)
        for i in range(len(weights)):
            z[i][0] = numpy.asarray((idx_list == i+1).sum(), dtype='int32')

    def grad(self, args, gz):
        return [None] * len(args)


# TODO The classes FillMissing and MaskSelect below should probably be moved
# to another (more appropriate) file.
class FillMissing(Op):
    """
    Given an input, output two elements:
        - a copy of the input where missing values (NaN) are replaced by some
        other value (zero by default)
        - a mask of the same size and type as input, where each element is zero
        iff the corresponding input is missing
    The 'fill_with' parameter may either be:
        - a scalar: all missing values are replaced with this value
        - a Numpy array: a missing value is replaced by the value in this array
        at the same position (ignoring the first k dimensions if 'fill_with'
        has k less dimensions than the input)
    Currently, the gradient is computed as if the input value was really what
    it was replaced with. It may be safer to replace the gradient w.r.t.
    missing values with either zeros or missing values (?).
    """

    def __init__(self, fill_with=0):
        super(FillMissing, self).__init__()
        self.fill_with = fill_with
        self.fill_with_is_array = isinstance(self.fill_with, numpy.ndarray)

    def __eq__(self, other):
        if type(self) != type(other):
            return False
        if self.fill_with_is_array != other.fill_with_is_array:
            return False
        if self.fill_with_is_array:
            # Element-wise == on arrays cannot be used directly as a boolean.
            return numpy.array_equal(self.fill_with, other.fill_with)
        return self.fill_with == other.fill_with

    def __hash__(self):
        if self.fill_with_is_array:
            # numpy arrays are unhashable: hash shape and contents instead
            # (consistent with numpy.array_equal in __eq__).
            fill_hash = (hash(self.fill_with.shape) ^
                         hash(tuple(self.fill_with.ravel().tolist())))
        else:
            fill_hash = hash(self.fill_with)
        return hash(type(self))^hash(self.fill_with_is_array)^fill_hash

    def make_node(self, input):
        input = T.as_tensor_variable(input)
        return Apply(self, [input], [input.type(), input.type()])

    def perform(self, node, inputs, output_storage):
        input, = inputs
        filled = input.copy()
        missing = numpy.isnan(filled)
        if self.fill_with_is_array:
            ignore_k = filled.ndim - self.fill_with.ndim
            if ignore_k < 0:
                raise ValueError('fill_with has more dimensions than the input',
                                 self.fill_with.shape, filled.shape)
            # fill_with is aligned with the trailing dimensions of the input.
            replacement = numpy.broadcast_arrays(filled, self.fill_with)[1]
            filled[missing] = replacement[missing]
        else:
            filled[missing] = self.fill_with
        output_storage[0][0] = filled
        # Recomputed on every call: a reused buffer would keep stale zeros.
        output_storage[1][0] = numpy.logical_not(missing).astype(input.dtype)

    def grad(self, inputs, output_grads):
        # Gradient passes straight through (see class docstring caveat).
        out_grad = output_grads[0]
        return [out_grad]

fill_missing_with_zeros = FillMissing(0)

class MaskGradient(Op):
    """
    Takes as input a tensor and a mask. Outputs the same tensor, but setting
    to zero the gradient for all elements where the mask's value is zero.
    """

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))

    def make_node(self, input, mask):
        input = T.as_tensor_variable(input)
        mask = T.as_tensor_variable(mask)
        return Apply(self, [input, mask], [input.type()])

    def perform(self, node, inputs, outputs):
        # Forward pass is the identity (on a copy of the input).
        outputs[0][0] = inputs[0].copy()

    def grad(self, inputs, output_grads):
        mask = inputs[1]
        out_grad, = output_grads
        # No gradient is propagated to the mask itself.
        return [out_grad * T.neq(mask, 0), None]

mask_gradient = MaskGradient()

class MaskSelect(Op):
    """
    Given an input x and a mask m (both vectors), outputs a vector that
    contains all elements x[i] such that bool(m[i]) is True.
    """

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))

    def make_node(self, input, mask):
        input = T.as_tensor_variable(input)
        mask = T.as_tensor_variable(mask)
        return Apply(self, [input, mask], [input.type()])

    def perform(self, node, inputs, outputs):
        input, mask = inputs
        # Positions i (in increasing order) where bool(mask[i]) is True.
        select = numpy.flatnonzero(mask)
        # Integer-array indexing returns a new array with input's dtype.
        outputs[0][0] = numpy.asarray(input)[select]

mask_select = MaskSelect()