view pylearn/algorithms/sandbox/DAA_inputs_groups.py @ 792:961dc1a7921b

added a save and load mechanism and fixed a bug in DAA_inputs_groups
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Fri, 10 Jul 2009 18:14:36 -0400
parents 166a89917669
children 4e70f509ec01
line wrap: on
line source

import numpy
import theano
import copy

from theano import tensor as T
from theano.compile import module

from pylearn.sandbox.scan_inputs_groups import scaninputs, scandotdec, scandotenc, scannoise, scanbiasdec, \
        scanmaskenc,scanmaskdec, FillMissing, mask_gradient

from pylearn.algorithms.logistic_regression import LogRegN

from pylearn.io import filetensor
import os

# saving loading utils--------------------------------------------
def save_mat(fname, mat, save_dir=''):
    assert isinstance(mat, numpy.ndarray)
    print 'save ndarray to file: ', fname
    file_handle = open(os.path.join(save_dir, fname), 'w')
    filetensor.write(file_handle, mat)
    file_handle.close()

def load_mat(fname, save_dir=''):
    print 'loading ndarray from file: ', fname
    file_handle = open(os.path.join(save_dir,fname), 'r')
    rval = filetensor.read(file_handle)
    file_handle.close()
    return rval


# Initialize containers:
class CreateContainer:
    """Empty attribute holder.

    Instances are used as plain namespaces onto which the symbolic variables
    of one propagation pass are attached as attributes.
    """

# regularisation utils:-------------------------------------------
def lnorm(param, type='l2'):
    """Return the symbolic l1 or (squared) l2 penalty of `param`.

    :param param: symbolic variable to penalise.
    :param type: 'l1' for the sum of absolute values, 'l2' for the sum of squares.
    :raises NotImplementedError: for any other value of `type`.
    """
    if type == 'l1':
        elementwise = T.abs(param)
    elif type == 'l2':
        elementwise = param*param
    else:
        raise NotImplementedError('Only l1 and l2 regularization are currently implemented')
    return T.sum(elementwise)

def get_reg_cost(params, type):
    """Accumulate the `type` norm penalty (see lnorm) over every entry of `params`.

    :param params: iterable of symbolic variables.
    :param type: norm name forwarded to lnorm ('l1' or 'l2').
    :returns: the accumulated penalty; the integer 0 when `params` is empty.
    """
    total = 0
    for weight in params:
        penalty = lnorm(weight, type)
        total += penalty
    return total

# activations utils:----------------------------------------------
def sigmoid_act(x):
    """Apply Theano's sigmoid non-linearity to the activation `x`."""
    squashed = T.nnet.sigmoid(x)
    return squashed

# The argument of tanh is divided by 2 so that its gradient has the same form as the sigmoid's
# (up to a constant factor), since sigmoid(x) = (tanh(x/2.0) + 1) / 2.0
def tanh_act(x):
    """Apply tanh to half of the activation `x`, i.e. tanh(x/2.0)."""
    halved = x/2.0
    return T.tanh(halved)

# costs utils:---------------------------------------------------
# To reduce the numerical instability of the cross-entropy cost and of its gradient, the cost is
# computed directly from the (pre-nonlinearity) activation with the following functions:
def sigmoid_cross_entropy(target, output_act, mean_axis, sum_axis):
    """Cross-entropy between `target` and sigmoid(`output_act`), from the activation.

    Uses log(sigmoid(a)) = -log(1 + exp(-a)) and log(1 - sigmoid(a)) = -log(1 + exp(a)).

    :param target: reconstruction target (presumably in [0, 1] -- confirm with callers).
    :param output_act: pre-sigmoid activation of the reconstruction.
    :param mean_axis: axis over which the summed cost is averaged.
    :param sum_axis: axis over which the element-wise cost is summed.
    :returns: symbolic cost, -mean(sum(XE, sum_axis), mean_axis).
    """
    log_term_on = T.log(1 + T.exp(-output_act))
    log_term_off = T.log(1 + T.exp(output_act))
    XE = -target * log_term_on + (1 - target) * (- log_term_off)
    summed = T.sum(XE, axis=sum_axis)
    return -T.mean(summed, axis=mean_axis)

def tanh_cross_entropy(target, output_act, mean_axis, sum_axis):
    """Cross-entropy for a tanh reconstruction unit, computed from the activation.

    The target is first mapped through (target+1)/2.0 (presumably from [-1, 1]
    to [0, 1] -- confirm with callers), then the same expression as for the
    sigmoid case is applied.

    :param target: reconstruction target.
    :param output_act: pre-nonlinearity activation of the reconstruction.
    :param mean_axis: axis over which the summed cost is averaged.
    :param sum_axis: axis over which the element-wise cost is summed.
    :returns: symbolic cost, -mean(sum(XE, sum_axis), mean_axis).
    """
    log_term_on = T.log(1 + T.exp(- output_act))
    log_term_off = T.log(1 + T.exp(output_act))
    XE = -(target+1)/2.0 * log_term_on + (1 - (target+1)/2.0) * (- log_term_off)
    summed = T.sum(XE, axis=sum_axis)
    return -T.mean(summed, axis=mean_axis)

def cross_entropy(target, output_act, act, mean_axis=0, sum_axis=1):
    """Dispatch to the cross-entropy matching the reconstruction activation `act`.

    :param target: reconstruction target.
    :param output_act: pre-nonlinearity activation of the reconstruction.
    :param act: name of the reconstruction activation, 'sigmoid_act' or 'tanh_act'.
    :param mean_axis: axis over which the summed cost is averaged (default 0).
    :param sum_axis: axis over which the element-wise cost is summed (default 1).
    :returns: symbolic cost.
    :raises AssertionError: if `act` is not one of the two supported names.
    """
    if act == 'sigmoid_act':
        return sigmoid_cross_entropy(target, output_act, mean_axis, sum_axis)
    if act == 'tanh_act':
        return tanh_cross_entropy(target, output_act, mean_axis, sum_axis)
    # Raised explicitly (instead of `assert False`) so that the check survives
    # `python -O`, where assert statements are removed and None would be returned.
    raise AssertionError('cross_entropy: unsupported activation %r' % (act,))

def quadratic(target, output, act, axis = 1):
    """Quadratic reconstruction cost, delegated to pylearn.algorithms.cost.quadratic.

    :param target: reconstruction target.
    :param output: reconstruction (after the non-linearity).
    :param act: unused; present so the signature matches cross_entropy's
        (target, output, activation name) calling convention.
    :param axis: axis forwarded to pylearn.algorithms.cost.quadratic (default 1).
    """
    # The module only does `from pylearn.<...> import <...>`, which never binds the
    # name `pylearn`; import the cost module here so the call below can resolve.
    import pylearn.algorithms.cost
    return pylearn.algorithms.cost.quadratic(target, output, axis)

# DAAig module----------------------------------------------------------------
class DAAig(module.Module):
    """De-noising Auto-encoder with inputs groups and missing values
    """
    
    def __init__(self, input = None, auxinput = None,
                in_size=None, auxin_size= None, n_hid=1,
                regularize = False, tie_weights = False, hid_fn = 'tanh_act',
                rec_fn = 'tanh_act',reconstruction_cost_function='cross_entropy',
                scale_cost = False, interface = True, ignore_missing=None, reconstruct_missing=False,
                corruption_pattern=None, **init):
        """
        :param input: WRITEME
        :param auxinput: WRITEME
        :param in_size: WRITEME
        :param auxin_size: WRITEME
        :param n_hid: WRITEME
        :param regularize: WRITEME
        :param tie_weights: WRITEME
        :param hid_fn: WRITEME
        :param rec_fn: WRITEME
        :param reconstruction_cost_function: WRITEME
        :param scale_cost: WRITEME
        :param interface: WRITEME
        :param ignore_missing: if not None, the input will be scanned in order
            to detect missing values, and these values will be replaced. Also,
            the reconstruction cost's gradient will be computed only on non
            missing components. The value of this parameter indicates how to
            replace missing values:
                - some numpy.ndarray: value of this array at the same index
                - a constant: this same value everywhere
            If None, the presence of missing values may cause crashes or other
            weird and unexpected behavior.
            Please note that this option only affects the permanent input, not
            auxilary ones (that should never contain missing values). In fact,
            in the current implementation, auxiliary inputs cannot be used when
            this option is True.
        :param corruption_pattern: if not None, may specify a particular way to
            corrupt the input with missing values. Valid choices are:
            - 'by_pair': consider that features are given as pairs, and corrupt
            (or not) the whole pair instead of considering them independently.
            Elements in a pair are not consecutive, instead they are assumed to
            be at distance (total number of features / 2) of each other.
        :param reconstruct_missing: if True, then the reconstruction cost on
        missing inputs will be backpropagated. Otherwise, it will not.
        :todo: Default noise level for all daa levels
        """
        print '\t\t**** DAAig.__init__ ****'
        print '\t\tinput = ', input
        print '\t\tauxinput = ', auxinput
        print '\t\tin_size = ', in_size
        print '\t\tauxin_size = ', auxin_size
        print '\t\tn_hid = ', n_hid
        print '\t\tregularize = ', regularize
        print '\t\ttie_weights = ', tie_weights
        print '\t\thid_fn = ', hid_fn
        print '\t\trec_fn = ', rec_fn
        print '\t\treconstruction_cost_function = ', reconstruction_cost_function
        print '\t\tscale_cost = ', scale_cost
        
        super(DAAig, self).__init__()
        self.random = T.RandomStreams()
        
        # MODEL CONFIGURATION
        self.in_size = in_size
        self.auxin_size = auxin_size
        self.n_hid = n_hid
        self.regularize = regularize
        self.tie_weights = tie_weights
        self.interface = interface
        self.ignore_missing = ignore_missing
        self.reconstruct_missing = reconstruct_missing
        self.corruption_pattern = corruption_pattern
        self.scale_cost = scale_cost
        
        assert hid_fn in ('sigmoid_act','tanh_act')
        self.hid_fn = eval(hid_fn)
        
        assert rec_fn in ('sigmoid_act','tanh_act')
        self.rec_fn = eval(rec_fn)
        self.rec_name = rec_fn
        
        assert reconstruction_cost_function in ('cross_entropy','quadratic')
        self.reconstruction_cost_function = eval(reconstruction_cost_function)
        self.reconstruction_cost_function_name = reconstruction_cost_function
        
        ### DECLARE MODEL VARIABLES and default
        self.input = input
        self.noisy_input = None
        if self.ignore_missing is not None and self.input is not None:
            no_missing = FillMissing(self.ignore_missing)(self.input)
            self.input = no_missing[0]  # With missing values replaced.
            self.input_missing_mask = no_missing[1] # Missingness pattern.
        else:
            self.input_missing_mask = None
        self.auxinput = auxinput
        self.idx_list = T.ivector('idx_list') if self.auxinput is not None else None
        self.noisy_idx_list, self.noisy_auxinput = None, None
        
        #parameters
        self.benc = T.dvector('benc')
        if self.input is not None:
            self.wenc = T.dmatrix('wenc')
            self.wdec = self.wenc.T if tie_weights else T.dmatrix('wdec')
            self.bdec = T.dvector('bdec')
        
        if self.auxinput is not None:
            self.wauxenc = [T.dmatrix('wauxenc%s'%i) for i in range(len(auxin_size))]
            self.wauxdec = [self.wauxenc[i].T if tie_weights else T.dmatrix('wauxdec%s'%i) for i in range(len(auxin_size))]
            self.bauxdec = [T.dvector('bauxdec%s'%i) for i in range(len(auxin_size))]
        
        #hyper-parameters
        if self.interface:
            self.lr = T.scalar('lr')
        
        self.noise_level = T.scalar('noise_level')
        self.noise_level_group = T.scalar('noise_level_group')
        
        # leave the chance for subclasses to initialize
        if self.__class__ == DAAig:
            self.init_behavioural()
        print '\t\t**** end DAAig.__init__ ****'
    
    ### BEHAVIOURAL MODEL
    def init_behavioural(self):
        if self.input is not None:
            self.noisy_input = self.corrupt_input()
        if self.auxinput is not None:
            self.noisy_idx_list , self.noisy_auxinput = \
                scannoise(self.idx_list, self.auxinput,self.noise_level, self.noise_level_group)
        
        self.noise = CreateContainer()
        self.clean = CreateContainer()
        
        self.define_behavioural(self.clean, self.input, self.idx_list, self.auxinput)
        self.define_behavioural(self.noise, self.noisy_input, self.noisy_idx_list, self.noisy_auxinput)
        
        self.define_regularization()  # call before cost
        self.define_cost(self.noise)  # the cost is only needed for the noise (not used for the clean part)
        self.define_params()
        if self.interface:
            self.define_gradients()
            self.define_interface()
     
    def define_behavioural(self, container, input, idx_list, auxinput):
        self.define_propup(container, input, idx_list , auxinput)
        container.hidden = self.hid_fn(container.hidden_activation)
        self.define_propdown(container, idx_list , auxinput)
        if self.input is not None:
            container.rec_in = self.rec_fn(container.rec_activation_in)
        if (self.auxinput is not None):
            container.rec_aux = self.rec_fn(container.rec_activation_aux)
        container.rec = self.rec_fn(container.rec_activation)
    
    def define_propup(self, container, input, idx_list, auxinput):
        if self.input is not None:
            container.hidden_activation = self.filter_up(input, self.wenc, self.benc)
            if self.auxinput is not None:
                container.hidden_activation += scandotenc(idx_list,auxinput,self.wauxenc)
        else:
            if self.auxinput is not None:
                container.hidden_activation = scandotenc(idx_list,auxinput,self.wauxenc) + self.benc
    
    def define_propdown(self, container, idx_list, auxinput):
        if self.input is not None:
            container.rec_activation_in = self.filter_down(container.hidden,self.wdec,self.bdec)
        if self.auxinput is not None:
            container.rec_activation_aux = scandotdec(idx_list,auxinput,container.hidden,self.wauxdec) +\
                    scanbiasdec(idx_list,auxinput,self.bauxdec)
        
        if (self.ignore_missing is not None and self.input is not None and not self.reconstruct_missing):
            # Apply mask to gradient to ensure we do not backpropagate on the
            # cost computed on missing inputs (that have been imputed).
            container.rec_activation_in = mask_gradient(container.rec_activation_in,
                    self.input_missing_mask)
        
        if (self.input is not None) and (self.auxinput is not None):
            container.rec_activation = T.join(1,container.rec_activation_in,container.rec_activation_aux)
        else:
            if self.input is not None:
                container.rec_activation = container.rec_activation_in
            if (self.auxinput is not None):
                container.rec_activation = container.rec_activation_aux
    
    def filter_up(self, vis, w, b=None):
        out = T.dot(vis, w)
        return out + b if b else out
    filter_down = filter_up
    
    def define_regularization(self):
        self.reg_coef = T.scalar('reg_coef')
        if self.auxinput is not None:
            self.Maskup = scanmaskenc(self.idx_list,self.wauxenc)
            self.Maskdown = scanmaskdec(self.idx_list,self.wauxdec)
            if type(self.Maskup) is not list:
                self.Maskup = [self.Maskup]
            if type(self.Maskdown) is not list:
                self.Maskdown = [self.Maskdown]
        listweights = []
        listweightsenc = []
        if self.auxinput is not None:
            listweights += [w*m for w,m in zip(self.Maskup,self.wauxenc)] + [w*m for w,m in zip(self.Maskdown,self.wauxdec)]
            listweightsenc += [w*m for w,m in zip(self.Maskup,self.wauxenc)]
        if self.input is not None:
            listweights += [self.wenc,self.wdec]
            listweightsenc += [self.wenc]
        self.regularization = self.reg_coef * get_reg_cost(listweights,'l2')
        self.regularizationenc = self.reg_coef * get_reg_cost(listweightsenc,'l2')
    
    def define_cost(self, container):
        if self.reconstruction_cost_function_name == 'cross_entropy':
            if (self.input is not None):
                container.reconstruction_cost_in = \
                    self.reconstruction_cost_function(self.input,container.rec_activation_in,self.rec_name)
            if (self.auxinput is not None):
                container.reconstruction_cost_aux = \
                    self.reconstruction_cost_function(scaninputs(self.idx_list,self.auxinput),container.rec_activation_aux,\
                    self.rec_name)
        else:
            if (self.input is not None):
                container.reconstruction_cost_in = \
                    self.reconstruction_cost_function(self.input,container.rec_in,self.rec_name)
            if (self.auxinput is not None):
                container.reconstruction_cost_aux = \
                    self.reconstruction_cost_function(scaninputs(self.idx_list,self.auxinput),container.rec_aux,\
                    self.rec_name)
        # TOTAL COST
        if (self.input is not None) and (self.auxinput is not None):
            container.reconstruction_cost = (T.constant(min(1,1+self.scale_cost)) *container.reconstruction_cost_in +\
                T.constant(min(1,1-self.scale_cost)) * container.reconstruction_cost_aux )
        else:
            if self.input is not None:
                container.reconstruction_cost = container.reconstruction_cost_in
            if (self.auxinput is not None):
                container.reconstruction_cost = container.reconstruction_cost_aux
        
        if self.regularize: #if stacked don't merge regularization and cost here but in the stackeddaaig module
            container.cost = container.reconstruction_cost + self.regularization
        else:
            container.cost = container.reconstruction_cost
    
    def define_params(self):
        if not hasattr(self,'params'):
            self.params = []
        self.params += [self.benc]
        self.paramsenc = copy.copy(self.params)
        if self.input is not None:
            self.params += [self.wenc] + [self.bdec]
            self.paramsenc += [self.wenc]
        if self.auxinput is not None:
            self.params += self.wauxenc + self.bauxdec
            self.paramsenc += self.wauxenc
        if not(self.tie_weights):
            if self.input is not None:
                self.params += [self.wdec]
            if self.auxinput is not None:
                self.params += self.wauxdec
    
    def define_gradients(self):
        self.gradients = T.grad(self.noise.cost, self.params)
        self.updates = dict((p, p - self.lr * g) for p, g in \
                zip(self.params, self.gradients))
    
    def define_interface(self):
        # declare function to interface with module (if not stacked)
        if self.input is None:
            listin = [self.idx_list, self.auxinput]
        if self.auxinput is None:
            listin = [self.input]
        if (self.input is not None) and (self.auxinput is not None):
            listin =[self.input,self.idx_list, self.auxinput]
        self.update = theano.Method(listin, self.noise.cost, self.updates)
        self.compute_cost = theano.Method(listin, self.noise.cost)
        if self.input is not None:
            self.noisify = theano.Method(listin, self.noisy_input)
        if self.auxinput is not None:
            self.auxnoisify = theano.Method(listin, self.noisy_auxinput)
        self.reconstruction = theano.Method(listin, self.clean.rec)
        self.representation = theano.Method(listin, self.clean.hidden)
    
    def corrupt_input(self):
        if self.corruption_pattern is None:
            mask = self.random.binomial(T.shape(self.input), 1, 1 - self.noise_level)
        elif self.corruption_pattern == 'by_pair':
            shape = T.shape(self.input)
            # Do not ask me why, but just doing "/ 2" does not work (there is
            # a bug in the optimizer).
            shape = T.stack(shape[0], (shape[1] * 2) / 4)
            mask = self.random.binomial(shape, 1, 1 - self.noise_level)
            mask = T.horizontal_stack(mask, mask)
        else:
            raise ValueError('Unknown value for corruption_pattern: %s'
                    % self.corruption_pattern)
        return mask * self.input
    
    def _instance_initialize(self, obj, lr = 1 , reg_coef = 0, noise_level = 0 , noise_level_group = 0,
                            seed=1, alloc=True, **init):
        super(DAAig, self)._instance_initialize(obj, **init)
        
        obj.reg_coef = reg_coef
        obj.noise_level = noise_level
        obj.noise_level_group = noise_level_group
        if self. interface:
            obj.lr = lr # if stacked useless (overriden by the sup_lr and unsup_lr of the stackeddaaig module)
        else:
            obj.lr = None
        
        obj.random.initialize()
        if seed is not None:
            obj.random.seed(seed)
        self.R = numpy.random.RandomState(seed)
        
        obj.__hide__ = ['params']
        
        if self.input is not None:
            self.inf = 1/numpy.sqrt(self.in_size)
        if self.auxinput is not None:
            self.inf = 1/numpy.sqrt(sum(self.auxin_size))
        if (self.auxinput is not None) and (self.input is not None):
            self.inf = 1/numpy.sqrt(sum(self.auxin_size)+self.in_size)
        self.hif = 1/numpy.sqrt(self.n_hid)
        
        if alloc:
            if self.input is not None:
                wencshp = (self.in_size, self.n_hid)
                wdecshp = tuple(reversed(wencshp))
                print 'wencshp = ', wencshp
                print 'wdecshp = ', wdecshp
                obj.wenc = self.R.uniform(size=wencshp, low = -self.inf, high = self.inf)
                if not(self.tie_weights):
                    obj.wdec = self.R.uniform(size=wdecshp, low=-self.hif, high=self.hif)
                obj.bdec = numpy.zeros(self.in_size)
            
            if self.auxinput is not None:
                wauxencshp = [(i, self.n_hid) for i in self.auxin_size]
                wauxdecshp = [tuple(reversed(i)) for i in wauxencshp]
                print 'wauxencshp = ', wauxencshp
                print 'wauxdecshp = ', wauxdecshp
                obj.wauxenc = [self.R.uniform(size=i, low = -self.inf, high = self.inf) for i in wauxencshp]
                if not(self.tie_weights):
                    obj.wauxdec = [self.R.uniform(size=i, low=-self.hif, high=self.hif) for i in wauxdecshp]
                obj.bauxdec = [numpy.zeros(i) for i in self.auxin_size]
            
            print 'self.inf = ', self.inf
            print 'self.hif = ', self.hif
            
            obj.benc = numpy.zeros(self.n_hid)
            

#-----------------------------------------------------------------------------------------------------------------------

class StackedDAAig(module.Module):
    def __init__(self, depth = 1, input = T.dmatrix('input'), auxinput = [None],
                in_size = None, auxin_size = [None], n_hid = [1],
                regularize = False, tie_weights = False, hid_fn = 'tanh_act',
                rec_fn = 'tanh_act',reconstruction_cost_function='cross_entropy', scale_cost=False,
                n_out = 2, target = None, debugmethod = False, totalupdatebool=False,
                ignore_missing=None, reconstruct_missing=False,
                corruption_pattern=None,
                **init):
        """Build a stack of `depth` DAAig layers topped by a LogRegN layer.

        self.daaig holds depth+1 entries: the DAAig layers (created with
        interface=False and regularize=False) followed by the supervised
        LogRegN. Each layer takes the clean hidden code of the previous one as
        its main input. For every layer three kinds of update methods are built:
            - localupdate[i]: updates the parameters of layer i only;
            - globalupdate[i]: updates the parameters of layer i and the
              encoder parameters of all the layers below it;
            - totalupdate[i] (only if totalupdatebool): updates all the
              parameters created so far with the sum of the costs so far.
        Unsupervised updates use unsup_lr, the supervised ones sup_lr.

        :param depth: number of DAAig layers.
        :param input: symbolic main input of the first layer.
            NOTE(review): the default variable is created once, when the
            function is defined, and is thus shared by every instance relying
            on it.
        :param auxinput: list of auxiliary inputs, one per layer that has an
            auxiliary part (layers without one are skipped when indexing).
        :param in_size: size of the main input of the first layer.
        :param auxin_size: indexed per layer (auxin_size[i]); a None entry
            means that layer i has no auxiliary input. It is used as given,
            not broadcast to `depth` entries.
        :param n_hid, tie_weights, hid_fn, rec_fn, reconstruction_cost_function,
            scale_cost: per-layer DAAig settings; a list of length `depth` is
            used as is, for any other list its first element is repeated, and
            a non-list value is repeated for every layer.
        :param regularize: if True, regularization terms are added to the
            local, global and total costs.
        :param n_out: number of outputs given to LogRegN.
        :param target: symbolic target; an lvector named 'target' is created
            when None.
        :param debugmethod: if True, extra methods exposing representations,
            reconstructions, noisy inputs, costs and gradients are created.
        :param totalupdatebool: if True, the total costs/gradients/updates are
            built as well.
        :param ignore_missing, reconstruct_missing, corruption_pattern:
            forwarded unchanged to every DAAig layer.
        :param init: not used in this constructor.
        """
        
        super(StackedDAAig, self).__init__()
        
        # utils
        def listify(param,depth):
            # Broadcast a setting to one value per layer.
            if type(param) is list:
                return param if len(param)==depth else [param[0]]*depth
            else:
                return [param]*depth
        
        # save parameters
        self.depth = depth
        self.input = input
        self.auxinput = auxinput
        self.in_size = in_size
        # NOTE(review): self-assignment, has no effect; auxin_size is neither
        # listified nor stored on self.
        auxin_size = auxin_size
        self.n_hid = listify(n_hid,depth)
        self.regularize = regularize
        tie_weights = listify(tie_weights,depth)
        hid_fn = listify(hid_fn,depth)
        rec_fn = listify(rec_fn,depth)
        reconstruction_cost_function = listify(reconstruction_cost_function,depth)
        scale_cost = listify(scale_cost,depth)
        self.n_out = n_out
        self.target = target if target is not None else T.lvector('target')
        self.debugmethod = debugmethod
        self.totalupdatebool = totalupdatebool
        self.ignore_missing = ignore_missing
        self.reconstruct_missing = reconstruct_missing
        self.corruption_pattern = corruption_pattern
        
        print '\t**** StackedDAAig.__init__ ****'
        print '\tdepth = ', self.depth
        print '\tinput = ', self.input
        print '\tauxinput = ', self.auxinput
        print '\tin_size = ', self.in_size
        print '\tauxin_size = ', auxin_size
        print '\tn_hid = ', self.n_hid
        print '\tregularize = ', self.regularize
        print '\ttie_weights = ', tie_weights
        print '\thid_fn = ', hid_fn
        print '\trec_fn = ', rec_fn
        print '\treconstruction_cost_function = ', reconstruction_cost_function
        print '\tscale_cost = ', scale_cost
        print '\tn_out = ', self.n_out
        
        # init for model construction
        inputprec = input
        in_sizeprec = in_size
        self.daaig = [None] * (self.depth+1)
        
        #hyper parameters
        self.unsup_lr = T.dscalar('unsup_lr')
        self.sup_lr = T.dscalar('sup_lr')
        
        # updatemethods
        self.localupdate = [None] * (self.depth+1) #update only on the layer parameters
        self.globalupdate = [None] * (self.depth+1)#update wrt the layer cost backproped untill the input layer
        if self.totalupdatebool:
            self.totalupdate = [None] * (self.depth+1) #update wrt all the layers cost backproped untill the input layer
        
        # facultative methods
        if self.debugmethod:
            self.representation = [None] * (self.depth)
            self.reconstruction = [None] * (self.depth)
            self.noisyinputs = [None] * (self.depth)
            self.compute_localcost = [None] * (self.depth+1)
            self.compute_localgradients = [None] * (self.depth+1)
            self.compute_globalcost = [None] * (self.depth+1)
            self.compute_globalgradients = [None] * (self.depth+1)
            if self.totalupdatebool:
                self.compute_totalcost = [None] * (self.depth+1)
                self.compute_totalgradients = [None] * (self.depth+1)
        
        # some theano Variables we want to keep track on
        if self.regularize:
            self.regularizationenccost = [None] * (self.depth)
        self.localcost = [None] * (self.depth+1)
        self.localgradients = [None] * (self.depth+1)
        self.globalcost = [None] * (self.depth+1)
        self.globalgradients = [None] * (self.depth+1)
        if self.totalupdatebool:
            self.totalcost = [None] * (self.depth+1)
            self.totalgradients = [None] * (self.depth+1)
        
        #params to update and inputs initialization
        paramstot = []   # parameters of every layer built so far
        paramsenc = []   # encoder parameters of the layers below the current one
        self.inputs = [None] * (self.depth+1)
        if self.input is not None:
            self.inputs[0] = [self.input]
        else:
            self.inputs[0] = []
        
        # offset counts the layers without auxiliary input seen so far, so that
        # self.auxinput[i-offset] is the auxiliary input of layer i.
        offset = 0
        for i in range(self.depth):
            
            dict_params = dict(input = inputprec, in_size = in_sizeprec, auxin_size = auxin_size[i],
                    n_hid = self.n_hid[i], regularize = False, tie_weights = tie_weights[i], hid_fn = hid_fn[i],
                    rec_fn = rec_fn[i], reconstruction_cost_function = reconstruction_cost_function[i],
                    scale_cost = scale_cost[i], interface = False, ignore_missing = self.ignore_missing,
                    reconstruct_missing = self.reconstruct_missing,corruption_pattern = self.corruption_pattern)
            if auxin_size[i] is None:
                offset +=1
                dict_params.update({'auxinput' : None})
            else:
                dict_params.update({'auxinput' : self.auxinput[i-offset]})
            
            print '\tLayer init= ', i+1
            self.daaig[i] = DAAig(**dict_params)
            
            # method input, outputs and parameters update
            # (the inputs of layer i are those of layer i-1 plus its own auxiliary ones)
            if i:
                self.inputs[i] = copy.copy(self.inputs[i-1])
            if auxin_size[i] is not None:
                self.inputs[i] += [self.daaig[i].idx_list,self.auxinput[i-offset]]
            
            noisyout = []
            if inputprec is not None:
                noisyout += [self.daaig[i].noisy_input]
            if auxin_size[i] is not None:
                noisyout += [self.daaig[i].noisy_auxinput]
            
            paramstot += self.daaig[i].params
            
            # save the costs
            self.localcost[i] = self.daaig[i].noise.cost
            self.globalcost[i] = self.daaig[i].noise.cost
            if self.totalupdatebool:
                if i:
                    self.totalcost[i] = self.totalcost[i-1] + self.daaig[i].noise.cost
                else:
                    self.totalcost[i] = self.daaig[i].noise.cost
            
            if self.regularize:
                # regularizationenccost[i]: encoder regularization of layers 0..i-1
                if i:
                    self.regularizationenccost[i] = self.regularizationenccost[i-1]+self.daaig[i-1].regularizationenc
                else:
                    self.regularizationenccost[i] = 0
                self.localcost[i] += self.daaig[i].regularization
                self.globalcost[i] += self.regularizationenccost[i]
                if self.totalupdatebool:
                    self.totalcost[i] += self.daaig[i].regularization
            
            # paramsenc still holds only the layers below i at this point
            self.localgradients[i] = T.grad(self.localcost[i], self.daaig[i].params)
            self.globalgradients[i] = T.grad(self.globalcost[i], self.daaig[i].params+paramsenc)
            if self.totalupdatebool:
                self.totalgradients[i] = T.grad(self.totalcost[i], paramstot)
            
            #create the updates dictionnaries
            local_grads = dict((j, j - self.unsup_lr * g) for j,g in zip(self.daaig[i].params,self.localgradients[i]))
            global_grads = dict((j, j - self.unsup_lr * g)\
                    for j,g in zip(self.daaig[i].params+paramsenc,self.globalgradients[i]))
            if self.totalupdatebool:
                total_grads = dict((j, j - self.unsup_lr * g) for j,g in zip(paramstot,self.totalgradients[i]))
            
            # method declaration
            self.localupdate[i] = theano.Method(self.inputs[i],self.localcost[i],local_grads)
            self.globalupdate[i] = theano.Method(self.inputs[i],self.globalcost[i],global_grads)
            if self.totalupdatebool:
                self.totalupdate[i] = theano.Method(self.inputs[i],self.totalcost[i],total_grads)
            
            if self.debugmethod:
                self.representation[i] = theano.Method(self.inputs[i],self.daaig[i].clean.hidden_activation)
                self.reconstruction[i] = theano.Method(self.inputs[i],self.daaig[i].clean.rec)
                self.noisyinputs[i] =theano.Method(self.inputs[i], noisyout)
                self.compute_localcost[i] = theano.Method(self.inputs[i],self.localcost[i])
                self.compute_localgradients[i] = theano.Method(self.inputs[i],self.localgradients[i])
                self.compute_globalcost[i] = theano.Method(self.inputs[i],self.globalcost[i])
                self.compute_globalgradients[i] = theano.Method(self.inputs[i],self.globalgradients[i])
                if self.totalupdatebool:
                    self.compute_totalcost[i] = theano.Method(self.inputs[i],self.totalcost[i])
                    self.compute_totalgradients[i] = theano.Method(self.inputs[i],self.totalgradients[i])
            
            # prepare the next layer: it is fed with the clean hidden code of this one
            paramsenc += self.daaig[i].paramsenc
            inputprec = self.daaig[i].clean.hidden
            in_sizeprec = self.n_hid[i]
        
        # supervised layer------------------------------------------------------------------------
        print '\tLayer supervised init'
        self.inputs[-1] = copy.copy(self.inputs[-2])+[self.target]
        # The classifier is fed with the sigmoid of the last layer's clean hidden activation.
        self.daaig[-1] = LogRegN(in_sizeprec,self.n_out,sigmoid_act(self.daaig[-2].clean.hidden_activation),self.target)
        paramstot += self.daaig[-1].params
        
        if self.regularize:
            self.localcost[-1] = self.daaig[-1].regularized_cost
            self.globalcost[-1] = self.daaig[-1].regularized_cost + self.regularizationenccost[-1]
        else:
            self.localcost[-1] = self.daaig[-1].unregularized_cost
            self.globalcost[-1] = self.daaig[-1].unregularized_cost
        
        if self.totalupdatebool:
            # pair: [total unsupervised cost, supervised global cost]
            self.totalcost[-1] = [self.totalcost[-2], self.globalcost[-1]]
        
        self.localgradients[-1] = T.grad(self.localcost[-1], self.daaig[-1].params)
        self.globalgradients[-1] = T.grad(self.globalcost[-1], self.daaig[-1].params+paramsenc)
        if self.totalupdatebool:
            self.totalgradients[-1] = [T.grad(self.totalcost[-2], paramstot) ,\
                    T.grad(self.globalcost[-1], paramstot) ]
        
        local_grads = dict((j, j - self.sup_lr * g) for j,g in zip(self.daaig[-1].params,self.localgradients[-1]))
        global_grads = dict((j, j - self.sup_lr * g)\
                for j,g in zip(self.daaig[-1].params+paramsenc,self.globalgradients[-1]))
        if self.totalupdatebool:
            # each part of the total cost is applied with its own learning rate
            total_grads = dict((j, j - self.unsup_lr * g1 - self.sup_lr * g2)\
                    for j,g1,g2 in zip(paramstot,self.totalgradients[-1][0],self.totalgradients[-1][1]))
        
        self.localupdate[-1] = theano.Method(self.inputs[-1],self.localcost[-1],local_grads)
        self.globalupdate[-1] = theano.Method(self.inputs[-1],self.globalcost[-1],global_grads)
        if self.totalupdatebool:
            self.totalupdate[-1] = theano.Method(self.inputs[-1],self.totalcost[-1],total_grads)
        
        # one method applying the local update of every layer at once
        totallocal_grads={}
        for k in range(self.depth):
            totallocal_grads.update(dict((j, j - self.unsup_lr * g) for j,g in
                    zip(self.daaig[k].params,self.localgradients[k])))
        totallocal_grads.update(dict((j, j - self.sup_lr * g) for j,g in zip(self.daaig[-1].params,self.localgradients[-1])))
        self.totallocalupdate = theano.Method(self.inputs[-1],self.localcost,totallocal_grads)
        
        # interface for the user
        self.classify = theano.Method(self.inputs[-2],self.daaig[-1].argmax_standalone)
        self.NLL = theano.Method(self.inputs[-1],self.daaig[-1]._xent)
        
        if self.debugmethod:
            self.compute_localcost[-1] = theano.Method(self.inputs[-1],self.localcost[-1])
            self.compute_localgradients[-1] = theano.Method(self.inputs[-1],self.localgradients[-1])
            self.compute_globalcost[-1] = theano.Method(self.inputs[-1],self.globalcost[-1])
            self.compute_globalgradients[-1] = theano.Method(self.inputs[-1],self.globalgradients[-1])
            if self.totalupdatebool:
                self.compute_totalcost[-1] = theano.Method(self.inputs[-1],self.totalcost[-1])
                self.compute_totalgradients[-1] =\
                        theano.Method(self.inputs[-1],self.totalgradients[-1][0]+self.totalgradients[-1][1])
    
    def _instance_initialize(self,inst,unsup_lr = 0.1, sup_lr = 0.01, reg_coef = 0,
                                noise_level = 0 , noise_level_group = 0, seed = 1, alloc = True,**init):
        super(StackedDAAig, self)._instance_initialize(inst, **init)
        
        inst.unsup_lr = unsup_lr
        inst.sup_lr = sup_lr
        
        for i in range(self.depth):
            print '\tLayer = ', i+1
            inst.daaig[i].initialize(reg_coef = reg_coef, noise_level = noise_level,\
                    noise_level_group = noise_level_group, seed = seed + i, alloc = alloc)
        print '\tLayer supervised'
        inst.daaig[-1].initialize()
        if alloc:
            inst.daaig[-1].R = numpy.random.RandomState(seed+self.depth)
            # init the logreg weights
            inst.daaig[-1].w = inst.daaig[-1].R.uniform(size=inst.daaig[-1].w.shape,\
                                low = -1/numpy.sqrt(inst.daaig[-2].n_hid), high = 1/numpy.sqrt(inst.daaig[-2].n_hid))
        inst.daaig[-1].l1 = 0
        inst.daaig[-1].l2 = reg_coef #only l2 norm for regularisation to be consitent with the unsup regularisation
    
    def _instance_save(self,inst,save_dir=''):
        """Save the parameters of ``inst`` with ``save_mat``, one file each.

        For each layer ``i`` in ``range(self.depth)`` the following files
        are written, in this order:

        - ``benc<i>.ft``: always
        - ``wauxenc<i>_<j>.ft`` and ``bauxdec<i>_<j>.ft``: for each auxiliary
          group ``j``, when the layer has an auxiliary input
        - ``wenc<i>.ft`` and ``bdec<i>.ft``: when the layer has a main input
        - ``wauxdec<i>_<j>.ft`` and ``wdec<i>.ft``: same conditions as above,
          but only when the layer does not tie its weights

        The layer at index ``self.depth`` is saved last: its ``w`` goes to
        ``wenc<depth>.ft`` and its ``b`` to ``benc<depth>.ft``.

        :param inst: instance whose parameter values are saved
        :param save_dir: directory passed to ``save_mat`` for every file
        """
        for i in range(self.depth):
            # structure flags are read from the module (self), values from inst
            has_aux = self.daaig[i].auxinput is not None
            has_input = self.daaig[i].input is not None
            layer = inst.daaig[i]

            save_mat('benc%s.ft'%(i) ,layer.benc, save_dir)

            if has_aux:
                for j in range(len(layer.wauxenc)):
                    save_mat('wauxenc%s_%s.ft'%(i,j) ,layer.wauxenc[j], save_dir)
                    save_mat('bauxdec%s_%s.ft'%(i,j) ,layer.bauxdec[j], save_dir)

            if has_input:
                save_mat('wenc%s.ft'%(i) ,layer.wenc, save_dir)
                save_mat('bdec%s.ft'%(i) ,layer.bdec, save_dir)

            # decoder weights only exist as separate parameters when untied
            if not self.daaig[i].tie_weights:
                if has_aux:
                    for j in range(len(layer.wauxdec)):
                        save_mat('wauxdec%s_%s.ft'%(i,j) ,layer.wauxdec[j], save_dir)

                if has_input:
                    save_mat('wdec%s.ft'%(i) ,layer.wdec, save_dir)

        # Index of the layer saved last. Stated explicitly rather than derived
        # from the loop variable, which is unbound when the loop never runs.
        sup_idx = self.depth
        save_mat('wenc%s.ft'%(sup_idx) ,inst.daaig[sup_idx].w, save_dir)
        save_mat('benc%s.ft'%(sup_idx) ,inst.daaig[sup_idx].b, save_dir)

    def _instance_load(self,inst,save_dir='',coef = None, Sup_layer = None):
        """Load the parameters of ``inst`` from files read with ``load_mat``.

        The file names and the conditions under which each file is read
        mirror ``_instance_save``: ``benc<i>.ft``, ``wauxenc<i>_<j>.ft``,
        ``bauxdec<i>_<j>.ft``, ``wenc<i>.ft``, ``bdec<i>.ft`` and, for layers
        that do not tie their weights, ``wauxdec<i>_<j>.ft`` and
        ``wdec<i>.ft``, for each ``i`` in ``range(self.depth)``.

        :param inst: instance whose parameter values are overwritten
        :param save_dir: directory passed to ``load_mat`` for every file
        :param coef: sequence indexed by layer; every array loaded for layer
            ``i`` is divided by ``coef[i]``. Defaults to ``1`` for each of
            the ``self.depth`` layers.
        :param Sup_layer: index used in the file names ``wenc<k>.ft`` /
            ``benc<k>.ft`` from which the ``w`` and ``b`` of the layer at
            index ``self.depth`` are read. Defaults to ``self.depth``. These
            two arrays are not divided by any coefficient.
        """
        if coef is None:
            coef = [1]*self.depth

        for i in range(self.depth):
            # structure flags are read from the module (self), values go to inst
            has_aux = self.daaig[i].auxinput is not None
            has_input = self.daaig[i].input is not None
            layer = inst.daaig[i]
            scale = coef[i]

            layer.benc = load_mat('benc%s.ft'%(i), save_dir)/scale

            if has_aux:
                for j in range(len(layer.wauxenc)):
                    layer.wauxenc[j] = load_mat('wauxenc%s_%s.ft'%(i,j),save_dir)/scale
                    layer.bauxdec[j] = load_mat('bauxdec%s_%s.ft'%(i,j),save_dir)/scale

            if has_input:
                layer.wenc = load_mat('wenc%s.ft'%(i),save_dir)/scale
                layer.bdec = load_mat('bdec%s.ft'%(i),save_dir)/scale

            # decoder weights only exist as separate parameters when untied
            if not self.daaig[i].tie_weights:
                if has_aux:
                    for j in range(len(layer.wauxdec)):
                        layer.wauxdec[j] = load_mat('wauxdec%s_%s.ft'%(i,j),save_dir)/scale

                if has_input:
                    layer.wdec = load_mat('wdec%s.ft'%(i),save_dir)/scale

        # Index of the layer loaded last. Stated explicitly rather than derived
        # from the loop variable, which is unbound when the loop never runs.
        sup_idx = self.depth
        # Sup_layer only changes which files are read, not which layer is set.
        if Sup_layer is None:
            file_idx = sup_idx
        else:
            file_idx = Sup_layer
        inst.daaig[sup_idx].w = load_mat('wenc%s.ft'%(file_idx),save_dir)
        inst.daaig[sup_idx].b = load_mat('benc%s.ft'%(file_idx),save_dir)