view pylearn/algorithms/sandbox/DAA_inputs_groups.py @ 809:a66bef83e1fd

Changes in cost function, sum over quadratic and KL instead of cross entropy for global update for DAA inputs groups
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Wed, 12 Aug 2009 18:29:18 -0400
parents 316817114b15
children 3a4bc4a0dbf4
line wrap: on
line source

import numpy
import theano
import copy

from theano import tensor as T
from theano.compile import module

from pylearn.sandbox.scan_inputs_groups import scaninputs, scandotdec, scandotenc, scannoise, scanbiasdec, \
        scanmaskenc,scanmaskdec, FillMissing, mask_gradient

from pylearn.algorithms.logistic_regression import LogRegN
import pylearn.algorithms.cost

from pylearn.io import filetensor
import os

# saving loading utils--------------------------------------------
def save_mat(fname, mat, save_dir=''):
    assert isinstance(mat, numpy.ndarray)
    print 'save ndarray to file: ', save_dir + fname
    file_handle = open(os.path.join(save_dir, fname), 'w')
    filetensor.write(file_handle, mat)
    file_handle.close()

def load_mat(fname, save_dir=''):
    print 'loading ndarray from file: ', save_dir + fname
    file_handle = open(os.path.join(save_dir,fname), 'r')
    rval = filetensor.read(file_handle)
    file_handle.close()
    return rval

# Weight initialisation utils--------------------------------------

# time consuming but just a test (not conclusive)
def orthogonalinit(W,axis=1):
    """Return a copy of W whose vectors are orthogonalised group by group.

    The vectors are the columns of W when ``axis == 1`` and its rows when
    ``axis == 0``. Each vector has length ``bn``, and at most ``bn`` mutually
    orthogonal non-zero vectors exist in that space. The vectors are therefore
    processed in consecutive groups of ``bn``, with Gram-Schmidt applied
    inside each group. Every output vector is rescaled to the norm of the
    corresponding input vector. W itself is not modified.

    :param W: 2-D numpy array.
    :param axis: 1 to orthogonalise columns, 0 to orthogonalise rows.
    :returns: array with the same shape as W.
    """
    if axis == 0:
        W = W.T
    # bn: length of each vector, nb: number of vectors
    bn, nb = W.shape
    Worto = copy.copy(W)
    for i in range(nb):
        # First index of the group containing vector i. Recomputing it for
        # every i (instead of bumping it once at i == bn) keeps groups of
        # size bn, even when there are more than 2 * bn vectors.
        offset = (i // bn) * bn
        for j in range(offset, i):
            ref = Worto[:, j]
            orthoproj = (Worto[:, i] * ref).sum() * ref / (ref * ref).sum()
            Worto[:, i] = Worto[:, i] - orthoproj
        # Restore the original norm of vector i.
        Worto[:, i] = Worto[:, i] / numpy.sqrt((Worto[:, i] * Worto[:, i]).sum()) \
                * numpy.sqrt((W[:, i] * W[:, i]).sum())
    return Worto if axis == 1 else Worto.T

# @todo
def PCAinit(data, nhid):
    """Placeholder for a PCA-based weight initialisation (not implemented).

    :param data: currently ignored.
    :param nhid: currently ignored.
    :returns: None; this function is a stub.
    """
    return None

#-----------------------------------------------------------------

# Initialize containers:
class CreateContainer:
    """Empty attribute bag.

    DAAig stores the symbolic variables of one pass (hidden, rec,
    rec_activation, cost, ...) as attributes set on an instance.
    """

# regularisation utils:-------------------------------------------
def lnorm(param, type='l2'):
    """Symbolic norm penalty of a parameter.

    :param param: symbolic tensor to penalise.
    :param type: 'l1' for sum(|param|), 'l2' for sum(param ** 2) (the
        squared L2 norm).
    :raises NotImplementedError: for any other ``type``.
    """
    if type == 'l2':
        return T.sum(param * param)
    elif type == 'l1':
        return T.sum(T.abs_(param))
    else:
        raise NotImplementedError('Only l1 and l2 regularization are currently implemented')

def get_reg_cost(params, type):
    """Total ``lnorm`` penalty over a list of parameters.

    :param params: iterable of symbolic tensors.
    :param type: norm name forwarded to ``lnorm`` ('l1' or 'l2').
    :returns: 0 for an empty list, otherwise ``0 + lnorm(p0) + lnorm(p1) + ...``.
    """
    return sum([lnorm(p, type) for p in params], 0)

# activations utils:----------------------------------------------
def sigmoid_act(x):
    """Logistic sigmoid activation, applied symbolically to ``x``."""
    sigmoid = theano.tensor.nnet.sigmoid
    return sigmoid(x)

# tanh's input is halved so that tanh_act is an affine rescaling of sigmoid: sigmoid(x) = (tanh(x/2.0) + 1) / 2.0
def tanh_act(x):
    """tanh activation on the halved input: tanh(x / 2) == 2 * sigmoid(x) - 1."""
    halved = x / 2.0
    return theano.tensor.tanh(halved)

# costs utils:---------------------------------------------------
# To avoid numerical instability in the cost and gradient of the cross entropy, the functions below
# compute it directly from the pre-nonlinearity activation.
# XS (the negated entropy of the target) is subtracted so that the cost becomes the KL divergence, which matters for global updates.

def sigmoid_cross_entropy(target, output_act, mean_axis, sum_axis):
    """KL-style reconstruction cost for a sigmoid output, from its pre-activation.

    Computes ``-mean(sum(XE - XS, sum_axis), mean_axis)``, where XE is the
    negated cross entropy between ``target`` and ``sigmoid(output_act)``,
    and XS is the negated entropy of ``target`` (so XE - XS is minus the
    Bernoulli KL divergence).
    """
    # log(sigmoid(a)) == -log(1 + exp(-a)); log(1 - sigmoid(a)) == -log(1 + exp(a))
    log_p = - T.log(1 + T.exp(-output_act))
    log_not_p = - T.log(1 + T.exp(output_act))
    cross = target * log_p + (1 - target) * log_not_p
    self_term = T.xlogx.xlogx(target) + T.xlogx.xlogx(1 - target)
    per_example = T.sum(cross - self_term, axis=sum_axis)
    return -T.mean(per_example, axis=mean_axis)

def tanh_cross_entropy(target, output_act, mean_axis, sum_axis):
    """Same cost as ``sigmoid_cross_entropy``, for a ``tanh_act`` output.

    ``tanh_act(a) = tanh(a / 2) = 2 * sigmoid(a) - 1``, so mapping the target
    through ``(t + 1) / 2`` (presumably from [-1, 1] to [0, 1]) makes the
    sigmoid formula apply unchanged to the same pre-activation.
    """
    rescaled_target = (target + 1) / 2.0
    return sigmoid_cross_entropy(rescaled_target, output_act, mean_axis, sum_axis)

def cross_entropy(target, output_act, act, mean_axis=0, sum_axis=1):
    """Dispatch to the cross-entropy variant matching the output activation.

    :param target: symbolic reconstruction target.
    :param output_act: pre-activation of the reconstruction layer.
    :param act: 'sigmoid_act' or 'tanh_act'.
    :param mean_axis: axis averaged over (after summing).
    :param sum_axis: axis summed over.
    :raises ValueError: for an unsupported ``act``.
    """
    if act == 'sigmoid_act':
        return sigmoid_cross_entropy(target, output_act, mean_axis, sum_axis)
    if act == 'tanh_act':
        return tanh_cross_entropy(target, output_act, mean_axis, sum_axis)
    # An explicit raise still fires under `python -O`, where `assert False`
    # would be stripped and None silently returned as the cost.
    raise ValueError('cross_entropy: unsupported activation %r '
                     '(expected sigmoid_act or tanh_act)' % (act,))

def quadratic(target, output, act, mean_axis = 0):
    """Summed quadratic reconstruction cost.

    ``act`` is ignored; it is accepted so that ``quadratic`` and
    ``cross_entropy`` can be called with the same positional arguments.
    """
    per_unit = pylearn.algorithms.cost.quadratic(target, output, mean_axis)
    return T.sum(per_unit)

# DAAig module----------------------------------------------------------------
class DAAig(module.Module):
    """De-noising Auto-encoder with inputs groups and missing values
    """
    
    def __init__(self, input = None, auxinput = None,
                in_size=None, auxin_size= None, n_hid=1,
                regularize = False, tie_weights = False, tie_weights_aux = None, hid_fn = 'tanh_act',
                rec_fn = 'tanh_act',reconstruction_cost_function='cross_entropy',
                interface = True, ignore_missing=None, reconstruct_missing=False,
                corruption_pattern=None, **init):
        """
        :param input: WRITEME
        :param auxinput: WRITEME
        :param in_size: WRITEME
        :param auxin_size: WRITEME
        :param n_hid: WRITEME
        :param regularize: WRITEME
        :param tie_weights: WRITEME
        :param hid_fn: WRITEME
        :param rec_fn: WRITEME
        :param reconstruction_cost_function: WRITEME
        :param scale_cost: WRITEME
        :param interface: WRITEME
        :param ignore_missing: if not None, the input will be scanned in order
            to detect missing values, and these values will be replaced. Also,
            the reconstruction cost's gradient will be computed only on non
            missing components. The value of this parameter indicates how to
            replace missing values:
                - some numpy.ndarray: value of this array at the same index
                - a constant: this same value everywhere
            If None, the presence of missing values may cause crashes or other
            weird and unexpected behavior.
            Please note that this option only affects the permanent input, not
            auxilary ones (that should never contain missing values). In fact,
            in the current implementation, auxiliary inputs cannot be used when
            this option is True.
        :param corruption_pattern: if not None, may specify a particular way to
            corrupt the input with missing values. Valid choices are:
            - 'by_pair': consider that features are given as pairs, and corrupt
            (or not) the whole pair instead of considering them independently.
            Elements in a pair are not consecutive, instead they are assumed to
            be at distance (total number of features / 2) of each other.
        :param reconstruct_missing: if True, then the reconstruction cost on
        missing inputs will be backpropagated. Otherwise, it will not.
        :todo: Default noise level for all daa levels
        """
        super(DAAig, self).__init__()
        self.random = T.RandomStreams()
        
        # MODEL CONFIGURATION
        self.in_size = in_size
        self.auxin_size = auxin_size
        self.n_hid = n_hid
        self.regularize = regularize
        self.tie_weights = tie_weights
        self.tie_weights_aux = tie_weights_aux if tie_weights_aux is not None else tie_weights
        self.interface = interface
        self.ignore_missing = ignore_missing
        self.reconstruct_missing = reconstruct_missing
        self.corruption_pattern = corruption_pattern
        
        assert hid_fn in ('sigmoid_act','tanh_act')
        self.hid_fn = eval(hid_fn)
        
        assert rec_fn in ('sigmoid_act','tanh_act')
        self.rec_fn = eval(rec_fn)
        self.rec_name = rec_fn
        
        assert reconstruction_cost_function in ('cross_entropy','quadratic')
        self.reconstruction_cost_function = eval(reconstruction_cost_function)
        self.reconstruction_cost_function_name = reconstruction_cost_function
        
        print '\t\t**** DAAig.__init__ ****'
        print '\t\tinput = ', input
        print '\t\tauxinput = ', auxinput
        print '\t\tin_size = ', self.in_size
        print '\t\tauxin_size = ', self.auxin_size
        print '\t\tn_hid = ', self.n_hid
        print '\t\tregularize = ', self.regularize
        print '\t\ttie_weights = ', self.tie_weights
        print '\t\ttie_weights_aux = ', self.tie_weights_aux
        print '\t\thid_fn = ', hid_fn
        print '\t\trec_fn = ', rec_fn
        print '\t\treconstruction_cost_function = ', reconstruction_cost_function
        
        ### DECLARE MODEL VARIABLES and default
        self.input = input
        if self.ignore_missing is not None and self.input is not None:
            no_missing = FillMissing(self.ignore_missing)(self.input)
            self.input = no_missing[0]  # With missing values replaced.
            self.input_missing_mask = no_missing[1] # Missingness pattern.
        else:
            self.input_missing_mask = None
        
        self.auxinput = auxinput
        self.idx_list = T.ivector('idx_list') if self.auxinput is not None else None
        
        self.noisy_input, self.noisy_idx_list, self.noisy_auxinput = None , None, None 
        
        #parameters
        self.benc = T.dvector('benc')
        if self.input is not None:
            self.wenc = T.dmatrix('wenc')
            self.wdec = self.wenc.T if tie_weights else T.dmatrix('wdec')
            self.bdec = T.dvector('bdec')
        
        if self.auxinput is not None:
            self.wauxenc = [T.dmatrix('wauxenc%s'%i) for i in range(len(auxin_size))]
            self.wauxdec =[ self.wauxenc[i].T if tie_weights_aux else T.dmatrix('wauxdec%s'%i) for i in\
                    range(len(auxin_size))]
            self.bauxdec = [T.dvector('bauxdec%s'%i) for i in range(len(auxin_size))]
        
        #hyper-parameters
        if self.interface:
            self.lr = T.scalar('lr')
        self.noise_level = T.scalar('noise_level')
        self.noise_level_group = T.scalar('noise_level_group')
        self.scale_cost_in = T.scalar('scale_cost_in')
        self.scale_cost_aux = T.scalar('scale_cost_aux')
        
        # leave the chance for subclasses to initialize (example convolutionnal to implement)
        if self.__class__ == DAAig:
            self.init_behavioural()
        print '\t\t**** end DAAig.__init__ ****'
    
    ### BEHAVIOURAL MODEL
    def init_behavioural(self):
        if self.input is not None:
            self.noisy_input = self.corrupt_input()
        if self.auxinput is not None:
            self.noisy_idx_list , self.noisy_auxinput = \
                    scannoise(self.idx_list, self.auxinput,self.noise_level, self.noise_level_group)
        
        self.noise = CreateContainer()
        self.clean = CreateContainer()
        
        self.define_behavioural(self.clean, self.input, self.idx_list, self.auxinput)
        self.define_behavioural(self.noise, self.noisy_input, self.noisy_idx_list, self.noisy_auxinput)
        
        self.define_regularization()  # call before cost
        self.define_cost(self.noise)  # the cost is only needed for the noise (not used for the clean part)
        self.define_params()
        if self.interface:
            self.define_gradients()
            self.define_interface()
    
    def filter_up(self, vis, w, b=None):
        out = T.dot(vis, w)
        return out + b if b else out
    filter_down = filter_up
    
    def corrupt_input(self):
        if self.corruption_pattern is None:
            mask = self.random.binomial(T.shape(self.input), 1, 1 - self.noise_level)
        elif self.corruption_pattern == 'by_pair':
            shape = T.shape(self.input)
            # Do not ask me why, but just doing "/ 2" does not work (there is
            # a bug in the optimizer).
            shape = T.stack(shape[0], (shape[1] * 2) / 4)
            mask = self.random.binomial(shape, 1, 1 - self.noise_level)
            mask = T.horizontal_stack(mask, mask)
        else:
            raise ValueError('Unknown value for corruption_pattern: %s' % self.corruption_pattern)
        return mask * self.input
     
    def define_behavioural(self, container, input, idx_list, auxinput):
        self.define_propup(container, input, idx_list , auxinput)
        container.hidden = self.hid_fn(container.hidden_activation)
        
        self.define_propdown(container, idx_list , auxinput)
        container.rec = self.rec_fn(container.rec_activation)
        if self.input is not None:
            container.rec_in = self.rec_fn(container.rec_activation_in)
        if (self.auxinput is not None):
            container.rec_aux = self.rec_fn(container.rec_activation_aux)
    
    def define_propup(self, container, input, idx_list, auxinput):
        container.hidden_activation = self.benc
        if self.input is not None:
            container.hidden_activation += self.filter_up(input, self.wenc)
        if self.auxinput is not None:
            container.hidden_activation += scandotenc(idx_list,auxinput,self.wauxenc)
    
    def define_propdown(self, container, idx_list, auxinput):
        if self.input is not None:
            container.rec_activation_in = self.filter_down(container.hidden,self.wdec,self.bdec)
        if self.auxinput is not None:
            container.rec_activation_aux = scandotdec(idx_list,auxinput,container.hidden,self.wauxdec) +\
                    scanbiasdec(idx_list,auxinput,self.bauxdec)
        
        if (self.ignore_missing is not None and self.input is not None and not self.reconstruct_missing):
            # Apply mask to gradient to ensure we do not backpropagate on the
            # cost computed on missing inputs (that have been imputed).
            container.rec_activation_in = mask_gradient(container.rec_activation_in, self.input_missing_mask)
        
        if (self.input is not None) and (self.auxinput is not None):
            container.rec_activation = T.join(1,container.rec_activation_in,container.rec_activation_aux)
        else:
            container.rec_activation = container.rec_activation_in \
                    if self.input is not None else container.rec_activation_aux
    
    def define_regularization(self):
        self.reg_coef = T.scalar('reg_coef')
        if self.auxinput is not None:
            self.Maskup = scanmaskenc(self.idx_list,self.wauxenc)
            self.Maskdown = scanmaskdec(self.idx_list,self.wauxdec)
            if type(self.Maskup) is not list:
                self.Maskup = [self.Maskup]
            if type(self.Maskdown) is not list:
                self.Maskdown = [self.Maskdown]
        
        listweights = []
        listweightsenc = []
        if self.auxinput is not None:
            listweights += [w*m for w,m in zip(self.Maskup,self.wauxenc)] + [w*m for w,m in zip(self.Maskdown,self.wauxdec)]
            listweightsenc += [w*m for w,m in zip(self.Maskup,self.wauxenc)]
        if self.input is not None:
            listweights += [self.wenc,self.wdec]
            listweightsenc += [self.wenc]
        
        self.regularization = self.reg_coef * get_reg_cost(listweights,'l1')
        self.regularizationenc = self.reg_coef * get_reg_cost(listweightsenc,'l1')
    
    def define_cost(self, container):
        tmpbool = (self.reconstruction_cost_function_name == 'cross_entropy')
        if (self.input is not None):
            container.reconstruction_cost_in = \
                self.reconstruction_cost_function(self.input, container.rec_activation_in \
                if tmpbool else container.rec_in, self.rec_name)
        if (self.auxinput is not None):
            container.reconstruction_cost_aux = \
                self.reconstruction_cost_function(scaninputs(self.idx_list, self.auxinput), container.rec_activation_aux \
                if tmpbool else container.rec_aux, self.rec_name)
        
        # TOTAL COST
        if (self.input is not None) and (self.auxinput is not None):
            container.reconstruction_cost = self.scale_cost_in * \
                container.reconstruction_cost_in +  self.scale_cost_aux*\
                container.reconstruction_cost_aux
        else:
            if self.input is not None:
                container.reconstruction_cost = container.reconstruction_cost_in
            if (self.auxinput is not None):
                container.reconstruction_cost = container.reconstruction_cost_aux
        
        if self.regularize: #if stacked don't merge regularization and cost here but in the stackeddaaig module
            container.cost = container.reconstruction_cost + self.regularization
        else:
            container.cost = container.reconstruction_cost
    
    def define_params(self):
        if not hasattr(self,'params'):
            self.params = []
        
        self.params += [self.benc]
        self.paramsenc = copy.copy(self.params)
        
        if self.input is not None:
            self.params += [self.wenc] + [self.bdec]
            self.paramsenc += [self.wenc]
        if self.auxinput is not None:
            self.params += self.wauxenc + self.bauxdec
            self.paramsenc += self.wauxenc
        
        if not(self.tie_weights):
            if self.input is not None:
                self.params += [self.wdec]
        if not(self.tie_weights_aux):
            if self.auxinput is not None:
                self.params += self.wauxdec
    
    def define_gradients(self):
        self.gradients = T.grad(self.noise.cost, self.params)
        self.updates = dict((p, p - self.lr * g) for p, g in zip(self.params, self.gradients))
    
    def define_interface(self):
        # declare function to interface with module (if not stacked)
        listin = []
        listout = []
        if self.input is not None:
            listin += [self.input]
            listout += [self.noisy_input]
        if self.auxinput is None:
            listin += [self.idx_list, self.auxinput]
            listout += [self.noisy_auxinput]
        
        self.update = theano.Method(listin, self.noise.cost, self.updates)
        self.compute_cost = theano.Method(listin, self.noise.cost)
        self.noisify = theano.Method(listin, listout)
        self.recactivation = theano.Method(listin, self.noise.rec_activation)
        self.reconstruction = theano.Method(listin, self.noise.rec)
        self.activation = theano.Method(listin, self.clean.hidden_activation)
        self.representation = theano.Method(listin, self.clean.hidden)
    
    def _instance_initialize(self, obj, lr = 1 , reg_coef = 0, noise_level = 0 , noise_level_group = 0, scale_cost_in = 1,
                            scale_cost_aux = 1 , seed=1, orthoinit = False, tieinit = False, alloc=True, **init):
        super(DAAig, self)._instance_initialize(obj, **init)
        
        obj.reg_coef = reg_coef
        obj.noise_level = noise_level
        obj.noise_level_group = noise_level_group
        obj.scale_cost_in = scale_cost_in
        obj.scale_cost_aux = scale_cost_aux
        obj.lr = lr  if self.interface else None
        # if stacked useless (overriden by the sup_lr and unsup_lr of the stackeddaaig module)
        
        obj.random.initialize()
        obj.random.seed(seed)
        self.R = numpy.random.RandomState(seed)
        
        if self.input is not None:
            self.inf = 1/numpy.sqrt(self.in_size)
        if self.auxinput is not None:
            self.inf = 1/numpy.sqrt(sum(self.auxin_size))
        if (self.auxinput is not None) and (self.input is not None):
            self.inf = 1/numpy.sqrt(sum(self.auxin_size)+self.in_size)
        self.hif = 1/numpy.sqrt(self.n_hid)
        
        if alloc:
            if self.input is not None:
                wencshp = (self.in_size, self.n_hid)
                wdecshp = tuple(reversed(wencshp))
                obj.bdec = numpy.zeros(self.in_size)
                obj.wenc = self.R.uniform(size=wencshp, low = -self.inf, high = self.inf)
                if not(self.tie_weights):
                    obj.wdec = copy.copy(obj.wenc.T) if tieinit else \
                            self.R.uniform(size=wdecshp,low=-self.hif,high=self.hif)
                if orthoinit:
                    obj.wenc = orthogonalinit(obj.wenc)
                    if not(self.tie_weights):
                        obj.wdec = orthogonalinit(obj.wdec,0)
                print 'wencshp = ', wencshp
                print 'wdecshp = ', wdecshp
            
            if self.auxinput is not None:
                wauxencshp = [(i, self.n_hid) for i in self.auxin_size]
                wauxdecshp = [tuple(reversed(i)) for i in wauxencshp]
                obj.bauxdec = [numpy.zeros(i) for i in self.auxin_size]
                obj.wauxenc = [self.R.uniform(size=i, low = -self.inf, high = self.inf) for i in wauxencshp]
                if not(self.tie_weights_aux):
                    obj.wauxdec = [copy.copy(obj.wauxenc[i].T) for i in range(len(wauxdecshp))] if tieinit else\
                            [self.R.uniform(size=i, low=-self.hif, high=self.hif) for i in wauxdecshp]
                if orthoinit:
                    obj.wauxenc = [orthogonalinit(w) for w in obj.wauxenc]
                    if not(self.tie_weights_aux):
                        obj.wauxdec = [orthogonalinit(w,0) for w in obj.wauxdec]
                print 'wauxencshp = ', wauxencshp
                print 'wauxdecshp = ', wauxdecshp
            
            print 'self.inf = ', self.inf
            print 'self.hif = ', self.hif
            
            obj.benc = numpy.zeros(self.n_hid)
            

#-----------------------------------------------------------------------------------------------------------------------

class StackedDAAig(module.Module):
    def __init__(self, depth = 1, input = T.dmatrix('input'), auxinput = [None],
                in_size = None, auxin_size = [None], n_hid = [1],
                regularize = False, tie_weights = False, tie_weights_aux = None, hid_fn = 'tanh_act',
                rec_fn = 'tanh_act',reconstruction_cost_function='cross_entropy',
                n_out = 2, target = None, debugmethod = False, totalupdatebool=False,
                ignore_missing=None, reconstruct_missing=False,
                corruption_pattern=None,
                **init):
        
        super(StackedDAAig, self).__init__()
        
        # utils
        def listify(param,depth):
            if type(param) is list:
                return param if len(param)==depth else [param[0]]*depth
            else:
                return [param]*depth
        
        # save parameters
        self.depth = depth
        self.input = input
        self.auxinput = auxinput
        self.in_size = in_size
        auxin_size = auxin_size
        self.n_hid = listify(n_hid,depth)
        self.regularize = regularize
        tie_weights = listify(tie_weights,depth)
        tie_weights_aux = listify(tie_weights_aux,depth)
        hid_fn = listify(hid_fn,depth)
        rec_fn = listify(rec_fn,depth)
        reconstruction_cost_function = listify(reconstruction_cost_function,depth)
        self.n_out = n_out
        self.target = target if target is not None else T.lvector('target')
        self.debugmethod = debugmethod
        self.totalupdatebool = totalupdatebool
        self.ignore_missing = ignore_missing
        self.reconstruct_missing = reconstruct_missing
        self.corruption_pattern = corruption_pattern
        
        print '\t**** StackedDAAig.__init__ ****'
        print '\tdepth = ', self.depth
        print '\tinput = ', self.input
        print '\tauxinput = ', self.auxinput
        print '\tin_size = ', self.in_size
        print '\tauxin_size = ', auxin_size
        print '\tn_hid = ', self.n_hid
        print '\tregularize = ', self.regularize
        print '\ttie_weights = ', tie_weights
        print '\ttie_weights_aux = ', tie_weights_aux
        print '\thid_fn = ', hid_fn
        print '\trec_fn = ', rec_fn
        print '\treconstruction_cost_function = ', reconstruction_cost_function
        print '\tn_out = ', self.n_out
        
        # init for model construction
        inputprec = input
        in_sizeprec = in_size
        self.daaig = [None] * (self.depth+1)
        
        #hyper parameters
        self.unsup_lr = T.dscalar('unsup_lr')
        self.sup_lr = T.dscalar('sup_lr')
        
        # updatemethods
        self.localupdate = [None] * (self.depth+1) #update only on the layer parameters
        self.globalupdate = [None] * (self.depth+1)#update wrt the layer cost backproped untill the input layer
        if self.totalupdatebool:
            self.totalupdate = [None] * (self.depth+1) #update wrt all the layers cost backproped untill the input layer
        
        # facultative methods
        if self.debugmethod:
            self.activation = [None] * (self.depth)
            self.representation = [None] * (self.depth)
            self.recactivation = [None] * (self.depth)
            self.reconstruction = [None] * (self.depth)
            self.noisyinputs = [None] * (self.depth)
            self.compute_localgradients_in = [None] * (self.depth)
            self.compute_localgradients_aux = [None] * (self.depth)
            self.compute_localcost = [None] * (self.depth+1)
            self.compute_localgradients = [None] * (self.depth+1)
            self.compute_globalcost = [None] * (self.depth+1)
            self.compute_globalgradients = [None] * (self.depth+1)
            if self.totalupdatebool:
                self.compute_totalcost = [None] * (self.depth+1)
                self.compute_totalgradients = [None] * (self.depth+1)
        
        # some theano Variables we want to keep track on
        self.localgradients_in = [None] * (self.depth)
        self.localgradients_aux = [None] * (self.depth)
        self.localcost = [None] * (self.depth+1)
        self.localgradients = [None] * (self.depth+1)
        self.globalcost = [None] * (self.depth+1)
        self.globalgradients = [None] * (self.depth+1)
        if self.regularize:
            self.regularizationenccost = [None] * (self.depth)
        if self.totalupdatebool:
            self.totalcost = [None] * (self.depth+1)
            self.totalgradients = [None] * (self.depth+1)
        
        #params to update and inputs initialization
        paramstot = []
        paramsenc = []
        self.inputs = [None] * (self.depth+1)
        if self.input is not None:
            self.inputs[0] = [self.input]
        else:
            self.inputs[0] = []
        
        offset = 0
        for i in range(self.depth):
            
            dict_params = dict(input = inputprec, in_size = in_sizeprec, auxin_size = auxin_size[i],
                    n_hid = self.n_hid[i], regularize = False, tie_weights = tie_weights[i],
                    tie_weights_aux = tie_weights_aux[i], hid_fn = hid_fn[i],
                    rec_fn = rec_fn[i], reconstruction_cost_function = reconstruction_cost_function[i],
                    interface = False, ignore_missing = self.ignore_missing,
                    reconstruct_missing = self.reconstruct_missing,corruption_pattern = self.corruption_pattern)
            if auxin_size[i] is None:
                offset +=1
                dict_params.update({'auxinput' : None})
            else:
                dict_params.update({'auxinput' : self.auxinput[i-offset]})
            
            print '\tLayer init= ', i+1
            self.daaig[i] = DAAig(**dict_params)
            
            # method input, outputs and parameters update
            if i:
                self.inputs[i] = copy.copy(self.inputs[i-1])
            if auxin_size[i] is not None:
                self.inputs[i] += [self.daaig[i].idx_list,self.auxinput[i-offset]]
            
            noisyout = []
            if inputprec is not None:
                noisyout += [self.daaig[i].noisy_input]
            if auxin_size[i] is not None:
                noisyout += [self.daaig[i].noisy_auxinput]
            
            paramstot += self.daaig[i].params
            
            # save the costs
            self.localcost[i] = self.daaig[i].noise.cost
            self.globalcost[i] = self.daaig[i].noise.cost
            if self.totalupdatebool:
                self.totalcost[i] = self.totalcost[i-1] + self.daaig[i].noise.cost if i else self.daaig[i].noise.cost
            
            if self.regularize:
                self.regularizationenccost[i] = self.regularizationenccost[i-1]+self.daaig[i-1].regularizationenc if i else 0
                self.localcost[i] += self.daaig[i].regularization
                self.globalcost[i] += self.regularizationenccost[i] + self.daaig[i].regularization
                if self.totalupdatebool:
                    self.totalcost[i] += self.daaig[i].regularization
            
            self.localgradients_in[i] = T.grad(self.daaig[i].noise.reconstruction_cost_in, self.daaig[i].params) \
                if inputprec is not None else T.constant(0)
            self.localgradients_aux[i] = T.grad(self.daaig[i].noise.reconstruction_cost_aux,self.daaig[i].params) \
                if auxin_size[i] is not None else T.constant(0)
            self.localgradients[i] = T.grad(self.localcost[i], self.daaig[i].params)
            self.globalgradients[i] = T.grad(self.globalcost[i], paramsenc + self.daaig[i].params)
            if self.totalupdatebool:
                self.totalgradients[i] = T.grad(self.totalcost[i], paramstot)
            
            #create the updates dictionnaries
            local_grads = dict((j,j-self.unsup_lr*g) for j,g in zip(self.daaig[i].params,self.localgradients[i]))
            global_grads = dict((j,j-self.unsup_lr*g) for j,g in zip(paramsenc+self.daaig[i].params,self.globalgradients[i]))
            if self.totalupdatebool:
                total_grads = dict((j, j - self.unsup_lr * g) for j,g in zip(paramstot,self.totalgradients[i]))
            
            # method declaration
            self.localupdate[i] = theano.Method(self.inputs[i],self.localcost[i],local_grads)
            self.globalupdate[i] = theano.Method(self.inputs[i],self.globalcost[i],global_grads)
            if self.totalupdatebool:
                self.totalupdate[i] = theano.Method(self.inputs[i],self.totalcost[i],total_grads)
            
            if self.debugmethod:
                self.activation[i] = theano.Method(self.inputs[i],self.daaig[i].clean.hidden_activation)
                self.representation[i] = theano.Method(self.inputs[i],self.daaig[i].clean.hidden)
                self.recactivation[i] = theano.Method(self.inputs[i],self.daaig[i].noise.rec_activation)
                self.reconstruction[i] = theano.Method(self.inputs[i],self.daaig[i].noise.rec)
                self.noisyinputs[i] =theano.Method(self.inputs[i], noisyout)
                self.compute_localcost[i] = theano.Method(self.inputs[i],self.localcost[i])
                self.compute_localgradients[i] = theano.Method(self.inputs[i],self.localgradients[i])
                self.compute_localgradients_in[i] = theano.Method(self.inputs[i],self.localgradients_in[i])
                self.compute_localgradients_aux[i] = theano.Method(self.inputs[i],self.localgradients_aux[i])
                self.compute_globalcost[i] = theano.Method(self.inputs[i],self.globalcost[i])
                self.compute_globalgradients[i] = theano.Method(self.inputs[i],self.globalgradients[i])
                if self.totalupdatebool:
                    self.compute_totalcost[i] = theano.Method(self.inputs[i],self.totalcost[i])
                    self.compute_totalgradients[i] = theano.Method(self.inputs[i],self.totalgradients[i])
            
            paramsenc += self.daaig[i].paramsenc
            inputprec = self.daaig[i].clean.hidden
            in_sizeprec = self.n_hid[i]
        
        # supervised layer------------------------------------------------------------------------
        print '\tLayer supervised init'
        self.inputs[-1] = copy.copy(self.inputs[-2])+[self.target]
        self.daaig[-1] = LogRegN(in_sizeprec,self.n_out,sigmoid_act(self.daaig[-2].clean.hidden_activation),self.target)
        paramstot += self.daaig[-1].params
        
        self.localcost[-1] = self.daaig[-1].regularized_cost \
                if self.regularize else self.daaig[-1].unregularized_cost
        self.globalcost[-1] = self.daaig[-1].regularized_cost + self.regularizationenccost[-1] \
                if self.regularize else self.daaig[-1].unregularized_cost
        
        if self.totalupdatebool:
            self.totalcost[-1] = [self.totalcost[-2], self.globalcost[-1]]
        
        self.localgradients[-1] = T.grad(self.localcost[-1], self.daaig[-1].params)
        self.globalgradients[-1] = T.grad(self.globalcost[-1], paramsenc + self.daaig[-1].params)
        if self.totalupdatebool:
            self.totalgradients[-1] = [T.grad(self.totalcost[-2], paramstot) , T.grad(self.globalcost[-1],paramstot) ]
        
        local_grads = dict((j,j-self.sup_lr*g) for j,g in zip(self.daaig[-1].params,self.localgradients[-1]))
        global_grads = dict((j,j-self.sup_lr*g) for j,g in zip(paramsenc + self.daaig[-1].params,self.globalgradients[-1]))
        if self.totalupdatebool:
            total_grads = dict((j, j - self.unsup_lr * g1 - self.sup_lr * g2)\
                    for j,g1,g2 in zip(paramstot,self.totalgradients[-1][0],self.totalgradients[-1][1]))
        
        self.localupdate[-1] = theano.Method(self.inputs[-1],self.localcost[-1],local_grads)
        self.globalupdate[-1] = theano.Method(self.inputs[-1],self.globalcost[-1],global_grads)
        if self.totalupdatebool:
            self.totalupdate[-1] = theano.Method(self.inputs[-1],self.totalcost[-1],total_grads)
            # total update of each local cost [no global cost backpropagated]
            totallocal_grads={}
            for k in range(self.depth):
                totallocal_grads.update(dict((j, j - self.unsup_lr * g) for j,g in \
                        zip(self.daaig[k].params,self.localgradients[k])))
            totallocal_grads.update(dict((j, j - self.sup_lr * g) for j,g in
                    zip(self.daaig[-1].params,self.localgradients[-1])))
            self.totallocalupdate = theano.Method(self.inputs[-1],self.localcost,totallocal_grads)
        
        # interface for the user
        self.classify = theano.Method(self.inputs[-2],self.daaig[-1].argmax_standalone)
        self.NLL = theano.Method(self.inputs[-1],self.daaig[-1]._xent)
        
        if self.debugmethod:
            self.compute_localcost[-1] = theano.Method(self.inputs[-1],self.localcost[-1])
            self.compute_localgradients[-1] = theano.Method(self.inputs[-1],self.localgradients[-1])
            self.compute_globalcost[-1] = theano.Method(self.inputs[-1],self.globalcost[-1])
            self.compute_globalgradients[-1] = theano.Method(self.inputs[-1],self.globalgradients[-1])
            if self.totalupdatebool:
                self.compute_totalcost[-1] = theano.Method(self.inputs[-1],self.totalcost[-1])
                self.compute_totalgradients[-1] =\
                        theano.Method(self.inputs[-1],self.totalgradients[-1][0]+self.totalgradients[-1][1])
    
    def _instance_initialize(self,inst,unsup_lr = 0.01, sup_lr = 0.01, reg_coef = 0, scale_cost_in = 1, scale_cost_aux = 1,
                                noise_level = 0 , noise_level_group = 0, seed = 1, orthoinit = False, tieinit=False,
                                alloc = True,**init):
        super(StackedDAAig, self)._instance_initialize(inst, **init)
        
        inst.unsup_lr = unsup_lr
        inst.sup_lr = sup_lr
        
        for i in range(self.depth):
            print '\tLayer = ', i+1
            inst.daaig[i].initialize(reg_coef = reg_coef[i] if type(reg_coef) is list else reg_coef, \
                    noise_level = noise_level[i] if type(noise_level) is list else noise_level, \
                    scale_cost_in = scale_cost_in[i] if type(scale_cost_in) is list else scale_cost_in, \
                    scale_cost_aux = scale_cost_aux[i] if type(scale_cost_aux) is list else scale_cost_aux, \
                    noise_level_group = noise_level_group[i] if type(noise_level_group) is list else noise_level_group, \
                    seed = seed + i, orthoinit = orthoinit, tieinit = tieinit, alloc = alloc)
        
        print '\tLayer supervised'
        inst.daaig[-1].initialize()
        
        if alloc:
            inst.daaig[-1].R = numpy.random.RandomState(seed+self.depth)
            # init the logreg weights
            inst.daaig[-1].w = inst.daaig[-1].R.uniform(size=inst.daaig[-1].w.shape,\
                    low = -1/numpy.sqrt(inst.daaig[-2].n_hid), high = 1/numpy.sqrt(inst.daaig[-2].n_hid))
            if orthoinit:
                inst.daaig[-1].w = orthogonalinit(inst.daaig[-1].w)
        inst.daaig[-1].l1 = reg_coef[-1] if type(reg_coef) is list else reg_coef
        inst.daaig[-1].l2 = 0
        #only l1 norm for regularisation to be consitent with the unsup regularisation
    
    def _instance_save(self,inst,save_dir=''):
        """Save the parameters of every layer of `inst` as filetensor files.

        For DAA layer i:
          - benc<i>.ft always;
          - wenc<i>.ft and bdec<i>.ft when the layer has a main input;
          - wauxenc<i>_<j>.ft and bauxdec<i>_<j>.ft for each auxiliary group j
            when the layer has auxiliary inputs;
          - wdec<i>.ft / wauxdec<i>_<j>.ft only when the corresponding decoding
            weights are not tied.
        The supervised layer is stored at index self.depth as
        wenc<depth>.ft (its w) and benc<depth>.ft (its b).

        :param save_dir: directory the files are written to.
        """
        for i in range(self.depth):
            save_mat('benc%s.ft' % i, inst.daaig[i].benc, save_dir)

            if self.daaig[i].auxinput is not None:
                for j in range(len(inst.daaig[i].wauxenc)):
                    save_mat('wauxenc%s_%s.ft' % (i, j), inst.daaig[i].wauxenc[j], save_dir)
                    save_mat('bauxdec%s_%s.ft' % (i, j), inst.daaig[i].bauxdec[j], save_dir)

            if self.daaig[i].input is not None:
                save_mat('wenc%s.ft' % i, inst.daaig[i].wenc, save_dir)
                save_mat('bdec%s.ft' % i, inst.daaig[i].bdec, save_dir)

            # Untied decoding weights are independent parameters and must be saved too.
            if not self.daaig[i].tie_weights_aux and self.daaig[i].auxinput is not None:
                for j in range(len(inst.daaig[i].wauxdec)):
                    save_mat('wauxdec%s_%s.ft' % (i, j), inst.daaig[i].wauxdec[j], save_dir)

            if not self.daaig[i].tie_weights and self.daaig[i].input is not None:
                save_mat('wdec%s.ft' % i, inst.daaig[i].wdec, save_dir)

        # The supervised layer comes right after the self.depth DAA layers. Index
        # it explicitly instead of through the loop variable leaked by the for
        # loop, which is unbound when self.depth == 0.
        sup = self.depth
        save_mat('wenc%s.ft' % sup, inst.daaig[sup].w, save_dir)
        save_mat('benc%s.ft' % sup, inst.daaig[sup].b, save_dir)
    
    def _instance_load(self,inst,save_dir='',coefenc = None, coefdec = None, Sup_layer = None):
        """Load layer parameters saved with the file naming of _instance_save.

        :param save_dir: directory holding the .ft files ('' = current directory,
            matching how load_mat joins the path).
        :param coefenc: per-layer divisors applied to the loaded encoding
            parameters (benc, wenc, wauxenc); defaults to 1. for every layer.
        :param coefdec: per-layer divisors applied to the loaded decoding
            parameters (bdec, bauxdec, wdec, wauxdec); defaults to 1. for every layer.
        :param Sup_layer: when given, the supervised layer's w/b are read from
            wenc<Sup_layer>.ft / benc<Sup_layer>.ft instead of the files at
            index self.depth.

        When an untied decoding weight file is missing, a warning is printed and
        the transpose of the matching encoding weights (divided by coefdec) is
        used instead.
        """
        if coefenc is None:
            coefenc = [1.]*self.depth
        if coefdec is None:
            coefdec = [1.]*self.depth

        def saved(fname):
            # Resolve the path the same way load_mat does (os.path.join), so the
            # default save_dir='' means the current directory; os.listdir('')
            # would raise instead. Also avoids re-listing the directory per file.
            return os.path.exists(os.path.join(save_dir, fname))

        for i in range(self.depth):
            inst.daaig[i].benc = load_mat('benc%s.ft'%(i), save_dir)/coefenc[i]

            if self.daaig[i].auxinput is not None:
                for j in range(len(inst.daaig[i].wauxenc)):
                    inst.daaig[i].wauxenc[j] = load_mat('wauxenc%s_%s.ft'%(i,j),save_dir)/coefenc[i]
                    inst.daaig[i].bauxdec[j] = load_mat('bauxdec%s_%s.ft'%(i,j),save_dir)/coefdec[i]

            if self.daaig[i].input is not None:
                inst.daaig[i].wenc = load_mat('wenc%s.ft'%(i),save_dir)/coefenc[i]
                inst.daaig[i].bdec = load_mat('bdec%s.ft'%(i),save_dir)/coefdec[i]

            if not self.daaig[i].tie_weights_aux and self.daaig[i].auxinput is not None:
                for j in range(len(inst.daaig[i].wauxdec)):
                    if saved('wauxdec%s_%s.ft'%(i,j)):
                        inst.daaig[i].wauxdec[j] = load_mat('wauxdec%s_%s.ft'%(i,j),save_dir)/coefdec[i]
                    else:
                        print("WARNING: no decoding 'wauxdec%s_%s.ft' file use 'wauxenc%s_%s.ft' instead"%(i,j,i,j))
                        inst.daaig[i].wauxdec[j] = numpy.transpose(load_mat('wauxenc%s_%s.ft'%(i,j),save_dir)/coefdec[i])

            if not self.daaig[i].tie_weights and self.daaig[i].input is not None:
                if saved('wdec%s.ft'%(i)):
                    inst.daaig[i].wdec = load_mat('wdec%s.ft'%(i),save_dir)/coefdec[i]
                else:
                    print("WARNING: no decoding 'wdec%s.ft' file use 'wenc%s.ft' instead"%(i,i))
                    inst.daaig[i].wdec = numpy.transpose(load_mat('wenc%s.ft'%(i),save_dir)/coefdec[i])

        # The supervised layer sits right after the self.depth DAA layers; index
        # it explicitly rather than via the leaked loop variable (unbound when
        # self.depth == 0).
        sup = self.depth
        src = sup if Sup_layer is None else Sup_layer
        inst.daaig[sup].w = load_mat('wenc%s.ft'%(src),save_dir)
        inst.daaig[sup].b = load_mat('benc%s.ft'%(src),save_dir)
    
    def _instance_hidsaturation(self,inst,layer,inputs):
        """Saturation score of the clean hidden pre-activations of `layer`.

        Takes the median of the absolute activations along axis 1, then the
        mean of those medians. Relies on inst.activation, which is only
        compiled when debugmethod is enabled.
        """
        activations = inst.activation[layer](*inputs)
        medians = numpy.median(abs(activations), axis=1)
        return numpy.mean(medians)
    
    def _instance_recsaturation(self,inst,layer,inputs):
        """Saturation score of the noisy reconstruction pre-activations of `layer`.

        Takes the median of the absolute activations along axis 1, then the
        mean of those medians. Relies on inst.recactivation, which is only
        compiled when debugmethod is enabled.
        """
        rec_activations = inst.recactivation[layer](*inputs)
        medians = numpy.median(abs(rec_activations), axis=1)
        return numpy.mean(medians)
    
    def _instance_error(self,inst,inputs,target):
        """Percentage of examples whose predicted class differs from `target`."""
        predictions = inst.classify(*inputs)
        n_wrong = numpy.sum(predictions != target)
        return n_wrong / float(len(target)) * 100.0
    
    def _instance_nll(self,inst,inputs,target):
        """Average over examples of the supervised layer's _xent on (inputs, target).

        inst.NLL is called with the inputs followed by the target. Its output
        is summed and divided by len(target). `inputs` may be a list or a tuple.
        """
        # list() so tuple inputs work too (tuple + list raises TypeError)
        args = list(inputs) + [target]
        return numpy.sum(inst.NLL(*args)) / float(len(target))
    
    def _instance_unsupgrad(self,inst,inputs,layer,param_name):
        """Compare the input and auxiliary-input parts of a local gradient.

        Before each of the three gradient evaluations, the noise generators are
        re-seeded with 0 (presumably so all three see the same corruption).
        For the parameter of `layer` whose str() equals `param_name`, returns
            (tmpin, tmpaux, tmptot, inratio, projratio)
        where
          tmpin  = ||scale_cost_in  * gradin[j]||   (0 if gradin is not a list)
          tmpaux = ||scale_cost_aux * gradaux[j]||  (0 if gradaux is not a list)
          tmptot = ||gradtot[j]||                   (0 if gradtot is not a list)
          inratio   = 100 * tmpin / (tmpin + tmpaux)
          projratio = 100 * proj / tmpaux, with proj the projection of the scaled
                      aux gradient on the direction of the scaled input gradient.
        ||.|| is the square root of the sum of squared entries. A ratio whose
        denominator is 0 is reported as 0. Returns None if no parameter matches.
        Relies on the compute_localgradients{,_in,_aux} debug methods.
        """
        inst.noiseseed(0)
        gradin = inst.compute_localgradients_in[layer](*inputs)
        inst.noiseseed(0)
        gradaux = inst.compute_localgradients_aux[layer](*inputs)
        inst.noiseseed(0)
        gradtot = inst.compute_localgradients[layer](*inputs)

        scale_in = inst.daaig[layer].scale_cost_in
        scale_aux = inst.daaig[layer].scale_cost_aux
        in_is_list = type(gradin) is list
        aux_is_list = type(gradaux) is list

        for j in range(len(gradtot)):
            # compare names by value: `is` tests identity and fails for any
            # name string not interned to the very same object
            if str(self.daaig[layer].params[j]) != param_name:
                continue
            tmpin = numpy.sqrt((pow(scale_in,2)*gradin[j]*gradin[j]).sum()) if in_is_list else 0
            tmpaux = numpy.sqrt((pow(scale_aux,2)*gradaux[j]*gradaux[j]).sum()) if aux_is_list else 0
            tmptot = numpy.sqrt((gradtot[j]*gradtot[j]).sum()) if type(gradtot) is list else 0

            if in_is_list and aux_is_list and tmpin != 0:
                # projection of the scaled aux gradient on the scaled input gradient direction
                projauxin = (scale_aux*gradaux[j] * scale_in*gradin[j]).sum() / tmpin
            else:
                projauxin = 0

            total = tmpaux + tmpin
            inratio = tmpin/total*100 if total != 0 else 0
            projratio = projauxin/tmpaux*100 if tmpaux != 0 else 0
            return tmpin, tmpaux, tmptot, inratio, projratio
    
    def _instance_noiseseed(self,inst,seed):
        """Re-seed the noise random generators.

        The shared scannoise generator receives `seed`; the `random` generator
        of DAA layer k (0-based) receives seed + k + 1, so each layer gets a
        distinct seed.
        """
        scannoise.R.rand.seed(seed)
        for offset in range(1, self.depth + 1):
            inst.daaig[offset - 1].random.seed(seed + offset)
    
    def _instance_unsupupdate(self,inst,data,layer='all',typeup = 'local',printcost = False):
        """Run one unsupervised update step and return the resulting costs.

        :param data: for typeup 'local'/'global' with layer == 'all', a
            per-layer sequence where data[i] is the argument list of layer i;
            otherwise a single argument list.
        :param layer: a layer index, or 'all'.
        :param typeup: 'local', 'global', 'total' or 'totallocal'. With
            'totallocal', or 'total' with layer == 'all', the cost is stored in
            the last slot. Any other value performs no update.
        :param printcost: print the cost list when True.
        :returns: list of length self.depth holding the cost of each updated
            layer and None elsewhere.
        """
        cost = [None]*self.depth
        if typeup == 'totallocal':
            cost[-1] = inst.totallocalupdate(*data)
        elif typeup == 'total':
            if layer == 'all':
                cost[-1] = inst.totalupdate[-1](*data)
            else:
                cost[layer] = inst.totalupdate[layer](*data)
        elif typeup == 'local' or typeup == 'global':
            updates = inst.localupdate if typeup == 'local' else inst.globalupdate
            # `==`, not `is`: an equal but non-interned 'all' must still select
            # every layer (consistent with the 'total' branch above)
            if layer == 'all':
                for i in range(self.depth):
                    cost[i] = updates[i](*data[i])
            else:
                cost[layer] = updates[layer](*data)
        if printcost:
            print(cost)
        return cost
    
    def _instance_supupdate(self,inst,data,typeup = 'global',printcost = False):
        """Run one supervised update step and return its cost.

        :param data: argument list of the supervised update method.
        :param typeup: 'local' calls inst.localupdate[-1]; 'global' calls
            inst.globalupdate[-1].
        :param printcost: print the cost when True.
        :raises ValueError: if typeup is neither 'local' nor 'global'.
        """
        if typeup == 'local':
            cost = inst.localupdate[-1](*data)
        elif typeup == 'global':
            cost = inst.globalupdate[-1](*data)
        else:
            # previously fell through and left `cost` unbound
            raise ValueError("typeup must be 'local' or 'global', got %r" % (typeup,))
        if printcost:
            print(cost)
        return cost