view pylearn/algorithms/sandbox/DAA_inputs_groups.py @ 778:a985baadf74d

Merge
author Foo Bar <barfoo@iro.umontreal.ca>
date Sat, 13 Jun 2009 22:02:13 -0400
parents ba055d419bcf 72ce8288a283
children 2c159439c47c
line wrap: on
line source

import numpy
import theano
import copy

from theano import tensor as T
from theano.compile import module

from pylearn.sandbox.scan_inputs_groups import scaninputs, scandotdec, scandotenc, scannoise, scanbiasdec, \
        scanmaskenc,scanmaskdec, FillMissing, mask_gradient

from pylearn.algorithms.logistic_regression import LogRegN

# used to initialize containers
class ScratchPad:
    """Empty attribute bag; callers attach whatever fields they need to it."""

# regularisation utils:-------------------------------------------
def lnorm(param, type='l2'):
    """Return the symbolic l1 or l2 penalty of ``param``.

    :param param: expression to penalize; it is passed to Theano ops.
    :param type: 'l1' for the sum of absolute values, 'l2' for the sum of
        squares. The name shadows the builtin but is kept unchanged because
        callers may pass it by keyword.
    :raises NotImplementedError: for any other value of ``type``.
    """
    if type == 'l1':
        return T.sum(T.abs(param))
    if type == 'l2':
        # param*param instead of T.pow(param, 2): noted as faster by the
        # original author. The unreachable T.pow variant was removed.
        return T.sum(param*param)
    raise NotImplementedError('Only l1 and l2 regularization are currently implemented')

def get_reg_cost(params, type):
    """Add up the ``lnorm`` penalty of every entry of ``params``.

    :param params: iterable of expressions to penalize.
    :param type: norm name forwarded to ``lnorm``.
    :return: the accumulated penalty; the integer 0 when ``params`` is empty.
    """
    return sum((lnorm(weight, type) for weight in params), 0)

# activations utils:----------------------------------------------
def sigmoid_act(x):
    """Apply Theano's sigmoid (``theano.tensor.nnet.sigmoid``) to ``x``."""
    squash = theano.tensor.nnet.sigmoid
    return squash(x)

def tanh_act(x):
    """Apply tanh to half of ``x``, i.e. return ``tanh(x / 2)``."""
    halved = x/2.0
    return theano.tensor.tanh(halved)

# costs utils:---------------------------------------------------

# To avoid numerical instability in the cost and gradient computation of the cross entropy, we compute it
# with the following functions directly from the activation (the pre-nonlinearity value):

def sigmoid_cross_entropy(target, output_act, mean_axis, sum_axis):
    """Cross entropy of a sigmoid output, written in terms of its activation.

    Uses log(sigmoid(a)) = -log(1 + exp(-a)) and
    log(1 - sigmoid(a)) = -log(1 + exp(a)).

    :param target: reconstruction target.
    :param output_act: activation of the output units, before the sigmoid.
    :param mean_axis: axis over which the summed cost is averaged.
    :param sum_axis: axis over which the element-wise cost is summed.
    :return: the negated mean (over ``mean_axis``) of the summed log-likelihood.
    """
    log_term_on = T.log(1 + T.exp(-output_act))
    log_term_off = T.log(1 + T.exp(output_act))
    XE = -target * log_term_on + (1 - target) * (- log_term_off)
    summed = T.sum(XE, axis=sum_axis)
    return -T.mean(summed, axis=mean_axis)

def tanh_cross_entropy(target, output_act, mean_axis, sum_axis):
    """Cross entropy for the tanh case, written in terms of the activation.

    The target is rescaled with ``(target + 1) / 2`` and then plugged into
    the same expression as the sigmoid cross entropy.

    :param target: reconstruction target (presumably in [-1, 1] -- confirm
        against callers).
    :param output_act: activation of the output units, before the nonlinearity.
    :param mean_axis: axis over which the summed cost is averaged.
    :param sum_axis: axis over which the element-wise cost is summed.
    :return: the negated mean (over ``mean_axis``) of the summed log-likelihood.
    """
    log_term_on = T.log(1 + T.exp(- output_act))
    log_term_off = T.log(1 + T.exp(output_act))
    XE = -(target+1)/2.0 * log_term_on + (1 - (target+1)/2.0) * (- log_term_off)
    summed = T.sum(XE, axis=sum_axis)
    return -T.mean(summed, axis=mean_axis)

def cross_entropy(target, output_act, act, mean_axis=0, sum_axis=1):
    """Dispatch to the cross entropy matching the activation name ``act``.

    :param target: reconstruction target.
    :param output_act: activation of the output units (pre-nonlinearity).
    :param act: 'sigmoid_act' or 'tanh_act'.
    :param mean_axis: axis over which the cost is averaged.
    :param sum_axis: axis over which the element-wise cost is summed.
    :raises AssertionError: if ``act`` is not one of the two known names.
    """
    if act == 'sigmoid_act':
        return sigmoid_cross_entropy(target, output_act, mean_axis, sum_axis)
    if act == 'tanh_act':
        return tanh_cross_entropy(target, output_act, mean_axis, sum_axis)
    # Raised explicitly: a bare `assert False` is removed under `python -O`,
    # which would make this function silently return None.
    raise AssertionError('cross_entropy: unknown activation %r '
            '(expected sigmoid_act or tanh_act)' % (act,))

def quadratic(target, output, act, axis = 1):
    """Quadratic reconstruction cost, delegated to pylearn.algorithms.cost.

    :param target: reconstruction target.
    :param output: reconstruction produced by the model.
    :param act: unused; present so the signature lines up with
        ``cross_entropy`` (both are called as ``f(target, rec, hid_name)``).
    :param axis: forwarded to ``pylearn.algorithms.cost.quadratic``.
    """
    # The file only does `from pylearn.<...> import <...>`, which never binds
    # the bare name `pylearn`; import the module here so the lookup succeeds.
    import pylearn.algorithms.cost
    return pylearn.algorithms.cost.quadratic(target, output, axis)



# DAAig module----------------------------------------------------------------
class DAAig(module.Module):
    """De-noising Auto-encoder
    """
    
    def __init__(self, input = None, auxinput = None,
                in_size=None, auxin_size= None, n_hid=1,
                regularize = False, tie_weights = False, hid_fn = 'sigmoid_act',
                reconstruction_cost_function='cross_entropy', interface = True,
                ignore_missing=None, reconstruct_missing=False,
                corruption_pattern=None,
                **init):
        """
        :param regularize: WRITEME
        :param tie_weights: WRITEME
        :param hid_fn: WRITEME
        :param reconstruction_cost: Should return one cost per example (row)
        :param ignore_missing: if not None, the input will be scanned in order
            to detect missing values, and these values will be replaced. Also,
            the reconstruction cost's gradient will be computed only on non
            missing components. The value of this parameter indicates how to
            replace missing values:
                - some numpy.ndarray: value of this array at the same index
                - a constant: this same value everywhere
            If None, the presence of missing values may cause crashes or other
            weird and unexpected behavior.
            Please note that this option only affects the permanent input, not
            auxilary ones (that should never contain missing values). In fact,
            in the current implementation, auxiliary inputs cannot be used when
            this option is True.
        :param corruption_pattern: if not None, may specify a particular way to
        corrupt the input with missing values. Valid choices are:
            - 'by_pair': consider that features are given as pairs, and corrupt
            (or not) the whole pair instead of considering them independently.
            Elements in a pair are not consecutive, instead they are assumed to
            be at distance (total number of features / 2) of each other.
        :param reconstruct_missing: if True, then the reconstruction cost on
        missing inputs will be backpropagated. Otherwise, it will not.
        :todo: Default noise level for all daa levels
        """
        print '\t\t**** DAAig.__init__ ****'
        print '\t\tinput = ', input
        print '\t\tauxinput = ', auxinput
        print '\t\tin_size = ', in_size
        print '\t\tauxin_size = ', auxin_size
        print '\t\tn_hid = ', n_hid
        
        super(DAAig, self).__init__()
        self.random = T.RandomStreams()
        
        # MODEL CONFIGURATION
        self.in_size = in_size
        self.auxin_size = auxin_size
        self.n_hid = n_hid
        self.regularize = regularize
        self.tie_weights = tie_weights
        self.interface = interface
        self.ignore_missing = ignore_missing
        self.reconstruct_missing = reconstruct_missing
        self.corruption_pattern = corruption_pattern
        
        
        assert hid_fn in ('sigmoid_act','tanh_act')
        self.hid_fn = eval(hid_fn)
        self.hid_name = hid_fn
        
        assert reconstruction_cost_function in ('cross_entropy','quadratic')
        self.reconstruction_cost_function = eval(reconstruction_cost_function)
        self.reconstruction_cost_function_name = reconstruction_cost_function
        
        ### DECLARE MODEL VARIABLES and default
        self.input = input
        if self.ignore_missing is not None and self.input is not None:
            no_missing = FillMissing(self.ignore_missing)(self.input)
            self.input = no_missing[0]  # With missing values replaced.
            self.input_missing_mask = no_missing[1] # Missingness pattern.
        else:
            self.input_missing_mask = None
        self.noisy_input = None
        self.auxinput = auxinput
        self.idx_list = T.ivector('idx_list') if self.auxinput is not None else None
        self.noisy_idx_list, self.noisy_auxinput = None, None
        
        #parameters
        self.benc = T.dvector('benc')
        if self.input is not None:
            self.wenc = T.dmatrix('wenc')
            self.wdec = self.wenc.T if tie_weights else T.dmatrix('wdec')
            self.bdec = T.dvector('bdec')
        
        if self.auxinput is not None:
            self.wauxenc = [T.dmatrix('wauxenc%s'%i) for i in range(len(auxin_size))]
            self.wauxdec = [self.wauxenc[i].T if tie_weights else T.dmatrix('wauxdec%s'%i) for i in range(len(auxin_size))]
            self.bauxdec = [T.dvector('bauxdec%s'%i) for i in range(len(auxin_size))]
        
        #hyper-parameters
        if self.interface:
            self.lr = T.scalar('lr')
        self.noise_level = T.scalar('noise_level')
        self.noise_level_group = T.scalar('noise_level_group')
        
        # leave the chance for subclasses to initialize
        if self.__class__ == DAAig:
            self.init_behavioural()
        print '\t\t**** end DAAig.__init__ ****'
    
    ### BEHAVIOURAL MODEL
    def init_behavioural(self):
        """Build the whole symbolic graph of the model, in dependency order.

        Corrupts the inputs, builds the clean and noisy encode/decode paths,
        then the regularization, the costs and the parameter lists. When
        ``self.interface`` is set, the gradients and the module Methods are
        declared as well.
        """
        # Corrupted versions of whichever inputs are present.
        if self.input is not None:
            self.noisy_input = self.corrupt_input()
        if self.auxinput is not None:
            self.noisy_idx_list , self.noisy_auxinput = \
                scannoise(self.idx_list, self.auxinput,self.noise_level,
                        self.noise_level_group)
        
        # Containers receiving the intermediate expressions of each path.
        self.noise = ScratchPad()
        self.clean = ScratchPad()
        
        self.define_behavioural(self.clean, self.input, self.idx_list, self.auxinput)
        self.define_behavioural(self.noise, self.noisy_input, self.noisy_idx_list, self.noisy_auxinput)
        
        self.define_regularization()  # call before cost
        self.define_cost(self.clean)
        self.define_cost(self.noise)
        self.define_params()
        if self.interface:
            # define_interface uses self.updates, set by define_gradients.
            self.define_gradients()
            self.define_interface()
        
    def define_behavioural(self, container, input, idx_list, auxinput):
        """Fill ``container`` with the encode/decode expressions for one path.

        Sets ``hidden_activation``, ``hidden``, ``rec_activation`` and ``rec``
        on ``container``. ``self.hid_fn`` is applied to both the hidden and
        the reconstruction activations.
        """
        self.define_propup(container, input, idx_list , auxinput)
        container.hidden = self.hid_fn(container.hidden_activation)
        # define_propdown reads container.hidden, so it has to come after it.
        self.define_propdown(container, idx_list , auxinput)
        container.rec = self.hid_fn(container.rec_activation)
       
    def define_propup(self, container, input, idx_list, auxinput):
        """Set ``container.hidden_activation`` (encoder pre-activation).

        With a main input the activation is ``filter_up(input, wenc, benc)``,
        plus the auxiliary contribution when auxiliary inputs exist. With
        auxiliary inputs only, it is their contribution plus ``benc``.
        Nothing is set when the model has neither kind of input.
        """
        has_input = self.input is not None
        has_aux = self.auxinput is not None
        if has_input:
            activation = self.filter_up(input, self.wenc, self.benc)
            if has_aux:
                activation = activation + scandotenc(idx_list,auxinput,self.wauxenc)
            container.hidden_activation = activation
        elif has_aux:
            container.hidden_activation = scandotenc(idx_list,auxinput,self.wauxenc) + self.benc
        
    # DEPENDENCY: define_propup
    def define_propdown(self, container, idx_list, auxinput):
        """Set ``container.rec_activation`` (decoder pre-activation).

        The main-input part and the auxiliary part are decoded separately from
        ``container.hidden`` and joined along axis 1 when both exist. Unless
        ``reconstruct_missing`` is set, the result goes through
        ``mask_gradient`` when missing values are being imputed.
        """
        has_input = self.input is not None
        has_aux = self.auxinput is not None

        if has_input:
            visible_part = self.filter_down(container.hidden,self.wdec,self.bdec)
        if has_aux:
            aux_weights_part = scandotdec(idx_list,auxinput,container.hidden,self.wauxdec)
            aux_part = aux_weights_part + scanbiasdec(idx_list,auxinput,self.bauxdec)

        if has_input and has_aux:
            container.rec_activation = T.join(1,visible_part,aux_part)
        elif has_input:
            container.rec_activation = visible_part
        else:
            container.rec_activation = aux_part

        if self.ignore_missing is None or self.input is None or self.reconstruct_missing:
            return
        # Apply mask to gradient to ensure we do not backpropagate on the
        # cost computed on missing inputs (that have been imputed).
        container.rec_activation = mask_gradient(container.rec_activation,
                self.input_missing_mask)
  
    def filter_up(self, vis, w, b=None):
        """Return ``dot(vis, w)``, plus the bias ``b`` when one is given.

        :param vis: left operand of the product.
        :param w: weight matrix.
        :param b: optional bias; omitted from the result only when it is None.
        """
        out = T.dot(vis, w)
        # Identity test, not truthiness: the bias is a symbolic variable or
        # an array, whose truth value is not a reliable "was it supplied" flag.
        if b is not None:
            return out + b
        return out
    # The decoder uses the same affine map, called with the decoder parameters.
    filter_down = filter_up
    
    # TODO: fix regularization type (outside parameter ?)
    def define_regularization(self):
        """Declare ``reg_coef`` and the two l2 regularization expressions.

        ``self.regularization`` covers encoder and decoder weights,
        ``self.regularizationenc`` covers encoder weights only. Auxiliary
        weights are multiplied by the masks from ``scanmaskenc`` /
        ``scanmaskdec`` before being penalized.
        """
        def masked(masks, weights):
            # pair each mask with its weight matrix and multiply them
            return [mask*weight for mask, weight in zip(masks, weights)]

        self.reg_coef = T.scalar('reg_coef')
        has_aux = self.auxinput is not None
        if has_aux:
            self.Maskup = scanmaskenc(self.idx_list,self.wauxenc)
            self.Maskdown = scanmaskdec(self.idx_list,self.wauxdec)
            # the helpers may hand back a single mask; normalize to lists
            if type(self.Maskup) is not list:
                self.Maskup = [self.Maskup]
            if type(self.Maskdown) is not list:
                self.Maskdown = [self.Maskdown]

        all_weights = []
        encoder_weights = []
        if has_aux:
            all_weights.extend(masked(self.Maskup, self.wauxenc))
            all_weights.extend(masked(self.Maskdown, self.wauxdec))
            encoder_weights.extend(masked(self.Maskup, self.wauxenc))
        if self.input is not None:
            all_weights.extend([self.wenc, self.wdec])
            encoder_weights.append(self.wenc)

        self.regularization = self.reg_coef * get_reg_cost(all_weights,'l2')
        self.regularizationenc = self.reg_coef * get_reg_cost(encoder_weights,'l2')
    
    
    # DEPENDENCY: define_behavioural, define_regularization
    def define_cost(self, container):
        """Set ``reconstruction_cost`` and ``cost`` on ``container``.

        The cross entropy is computed from the reconstruction activation,
        any other cost from the reconstruction itself. ``cost`` additionally
        includes ``self.regularization`` when ``self.regularize`` is set.
        """
        if self.reconstruction_cost_function_name == 'cross_entropy':
            reconstructed = container.rec_activation
        else:
            reconstructed = container.rec
        container.reconstruction_cost = self.reconstruction_costs(reconstructed)

        # TOTAL COST
        total = container.reconstruction_cost
        if self.regularize: #if stacked don't merge regularization and cost here but in the stackeddaaig module
            total = total + self.regularization
        container.cost = total
    
    # DEPENDENCY: define_cost
    def define_params(self):
        """Build ``self.params`` (all trainable variables) and ``self.paramsenc``.

        ``params`` is extended in place if it already exists. Resulting order:
        benc, then wenc and bdec (main input), then wauxenc and bauxdec
        (auxiliary inputs), then, only with untied weights, wdec and wauxdec.
        ``paramsenc`` holds the encoder-side entries: benc, wenc, wauxenc.
        """
        if not hasattr(self,'params'):
            self.params = []
        self.params += [self.benc]
        # Copy taken now, so that paramsenc contains whatever params held up
        # to this point (including benc) but none of the decoder entries.
        self.paramsenc = copy.copy(self.params)
        if self.input is not None:
            self.params += [self.wenc] + [self.bdec]
            self.paramsenc += [self.wenc]
        if self.auxinput is not None:
            self.params += self.wauxenc + self.bauxdec
            self.paramsenc += self.wauxenc
        # With tied weights the decoder matrices are transposes of the encoder
        # ones (see __init__), so they are not separate parameters.
        if not(self.tie_weights):
            if self.input is not None:
                self.params += [self.wdec]
            if self.auxinput is not None:
                self.params += self.wauxdec
    
    # DEPENDENCY: define_cost, define_gradients
    def define_gradients(self):
        """Set ``self.gradients`` and the gradient-descent ``self.updates``.

        Gradients are those of the noisy cost with respect to ``self.params``;
        every parameter ``p`` is mapped to ``p - lr * gradient``. Uses
        ``self.lr``, which ``__init__`` declares only when ``interface`` is set.
        """
        self.gradients = T.grad(self.noise.cost, self.params)
        descent_steps = {}
        for param, grad in zip(self.params, self.gradients):
            descent_steps[param] = param - self.lr * grad
        self.updates = descent_steps
    
    
    # DEPENDENCY: define_behavioural, define_regularization, define_cost, define_gradients
    def define_interface(self):
        """Declare the module Methods used when the model is not stacked.

        Every Method takes the same input list: the main input, and/or the
        index list with the auxiliary inputs, depending on what the model has.
        """
        # declare function to interface with module (if not stacked)
        if self.auxinput is None:
            method_inputs = [self.input]
        elif self.input is None:
            method_inputs = [self.idx_list, self.auxinput]
        else:
            method_inputs = [self.input, self.idx_list, self.auxinput]

        noisy_cost = self.noise.cost
        # training step and cost of the corrupted path
        self.update = theano.Method(method_inputs, noisy_cost, self.updates)
        self.compute_cost = theano.Method(method_inputs, noisy_cost)
        # corrupted versions of the inputs
        if self.input is not None:
            self.noisify = theano.Method(method_inputs, self.noisy_input)
        if self.auxinput is not None:
            self.auxnoisify = theano.Method(method_inputs, self.noisy_auxinput)
        # outputs of the clean path
        self.reconstruction = theano.Method(method_inputs, self.clean.rec)
        self.representation = theano.Method(method_inputs, self.clean.hidden)
        self.validate = theano.Method(method_inputs, [self.clean.cost, self.clean.rec])
    
    def corrupt_input(self):
        """Return ``self.input`` multiplied by a random binary mask.

        Mask entries are drawn from a binomial with one trial and probability
        ``1 - noise_level``. With the 'by_pair' pattern a mask half as wide is
        drawn and stacked horizontally with itself, so both elements of a pair
        share the same mask value.

        :raises ValueError: for an unknown ``corruption_pattern``.
        """
        pattern = self.corruption_pattern
        if pattern is not None and not (pattern == 'by_pair'):
            raise ValueError('Unknown value for corruption_pattern: %s'
                    % self.corruption_pattern)

        by_pair = pattern is not None
        mask_shape = T.shape(self.input)
        if by_pair:
            # Do not ask me why, but just doing "/ 2" does not work (there is
            # a bug in the optimizer).
            mask_shape = T.stack(mask_shape[0], (mask_shape[1] * 2) / 4)
        mask = self.random.binomial(mask_shape, 1, 1 - self.noise_level)
        if by_pair:
            mask = T.horizontal_stack(mask, mask)
        return mask * self.input
    
    def reconstruction_costs(self, rec):
        """Return the reconstruction cost of ``rec`` against the model inputs.

        The target is the main input, the auxiliary inputs (through
        ``scaninputs``), or both joined along axis 1, depending on which of
        them the model has.

        :param rec: reconstruction (or its activation) to compare to the target.
        :raises AssertionError: if the model has neither kind of input.
        """
        has_input = self.input is not None
        has_aux = self.auxinput is not None
        if has_input and has_aux:
            target = T.join(1,self.input,scaninputs(self.idx_list,self.auxinput))
        elif has_input:
            target = self.input
        elif has_aux:
            target = scaninputs(self.idx_list,self.auxinput)
        else:
            # All cases should be covered above. If not, something is wrong!
            # Raised explicitly because a bare `assert False` is removed under
            # `python -O`, which would make this method return None.
            raise AssertionError('DAAig needs an input and/or an auxinput '
                    'to compute a reconstruction cost')
        return self.reconstruction_cost_function(target, rec, self.hid_name)
    
    def _instance_initialize(self, obj, lr = 1 , reg_coef = 0, noise_level = 0 , noise_level_group = 0,
                            seed=1, alloc=True, **init):
        """Set the hyper-parameters and, optionally, the initial weights of ``obj``.

        :param obj: module instance being initialized.
        :param lr: learning rate, stored only when ``self.interface`` is set.
        :param reg_coef: value given to ``obj.reg_coef``.
        :param noise_level: value given to ``obj.noise_level``.
        :param noise_level_group: value given to ``obj.noise_level_group``.
        :param seed: seed for ``obj.random`` (skipped when None) and for the
            numpy generator used to draw the weights.
        :param alloc: if True, allocate weights and biases.
        :param init: forwarded to the parent ``_instance_initialize``.

        Note that ``R``, ``inf`` and ``hif`` are stored on the module
        (``self``), not on ``obj``.
        """
        super(DAAig, self)._instance_initialize(obj, **init)
        
        obj.reg_coef = reg_coef
        obj.noise_level = noise_level
        obj.noise_level_group = noise_level_group
        if self. interface:
            obj.lr = lr # useless if stacked (overridden by the sup_lr and unsup_lr of the stackeddaaig module)
        else:
            obj.lr = None
        
        obj.random.initialize()
        if seed is not None:
            obj.random.seed(seed)
        self.R = numpy.random.RandomState(seed)
        
        obj.__hide__ = ['params']
        
        # Bound of the uniform range for encoder weights: 1/sqrt(number of
        # input features). The three tests run in order, so when both kinds
        # of input exist the last assignment (combined size) is the one kept.
        if self.input is not None:
            self.inf = 1/numpy.sqrt(self.in_size)
        if self.auxinput is not None:
            self.inf = 1/numpy.sqrt(sum(self.auxin_size))
        if (self.auxinput is not None) and (self.input is not None):
            self.inf = 1/numpy.sqrt(sum(self.auxin_size)+self.in_size)
        # Bound of the uniform range for decoder weights: 1/sqrt(n_hid).
        self.hif = 1/numpy.sqrt(self.n_hid)
        
        
        if alloc:
            # The draws below consume self.R in a fixed order; reordering
            # them would change the weights obtained for a given seed.
            if self.input is not None:
                wencshp = (self.in_size, self.n_hid)
                wdecshp = tuple(reversed(wencshp))
                print 'wencshp = ', wencshp
                print 'wdecshp = ', wdecshp
                
                obj.wenc = self.R.uniform(size=wencshp, low = -self.inf, high = self.inf)
                if not(self.tie_weights):
                    obj.wdec = self.R.uniform(size=wdecshp, low=-self.hif, high=self.hif)
                obj.bdec = numpy.zeros(self.in_size)
            
            if self.auxinput is not None:
                # one encoder/decoder matrix per auxiliary input group
                wauxencshp = [(i, self.n_hid) for i in self.auxin_size]
                wauxdecshp = [tuple(reversed(i)) for i in wauxencshp]
                print 'wauxencshp = ', wauxencshp
                print 'wauxdecshp = ', wauxdecshp
                
                obj.wauxenc = [self.R.uniform(size=i, low = -self.inf, high = self.inf) for i in wauxencshp]
                if not(self.tie_weights):
                    obj.wauxdec = [self.R.uniform(size=i, low=-self.hif, high=self.hif) for i in wauxdecshp]
                obj.bauxdec = [numpy.zeros(i) for i in self.auxin_size]
            
            # NOTE(review): self.inf is only set when the model has an input
            # or an auxinput; with neither, this print would fail -- confirm
            # that configuration is never used.
            print 'self.inf = ', self.inf
            print 'self.hif = ', self.hif
            
            obj.benc = numpy.zeros(self.n_hid)
            

#-----------------------------------------------------------------------------------------------------------------------

class StackedDAAig(module.Module):
    def __init__(self, depth = 1, input = T.dmatrix('input'), auxinput = [None],
                in_size = None, auxin_size = [None], n_hid = [1],
                regularize = False, tie_weights = False, hid_fn = 'sigmoid_act',
                reconstruction_cost_function='cross_entropy',
                n_out = 2, target = None, debugmethod = False, totalupdatebool=False,
                ignore_missing=None, reconstruct_missing=False,
                corruption_pattern=None,
                **init):
        
        super(StackedDAAig, self).__init__()
        print '\t**** StackedDAAig.__init__ ****'
        print '\tinput = ', input
        print '\tauxinput = ', auxinput
        print '\tin_size = ', in_size
        print '\tauxin_size = ', auxin_size
        print '\tn_hid = ', n_hid
        
        # save parameters
        self.depth = depth
        self.input = input
        self.auxinput = auxinput
        self.in_size = in_size
        auxin_size = auxin_size
        self.n_hid = n_hid
        self.regularize = regularize
        self.tie_weights = tie_weights
        self.hid_fn = hid_fn
        self.reconstruction_cost_function = reconstruction_cost_function
        self.n_out = n_out
        self.target = target if target is not None else T.lvector('target')
        self.debugmethod = debugmethod
        self.totalupdatebool = totalupdatebool
        self.ignore_missing = ignore_missing
        self.reconstruct_missing = reconstruct_missing
        self.corruption_pattern = corruption_pattern
        
        # init for model construction
        inputprec = input
        in_sizeprec = in_size
        self.daaig = [None] * (self.depth+1)
        
        #hyper parameters
        self.unsup_lr = T.dscalar('unsup_lr')
        self.sup_lr = T.dscalar('sup_lr')
        
        # updatemethods
        self.localupdate = [None] * (self.depth+1) #update only on the layer parameters
        self.globalupdate = [None] * (self.depth+1)#update wrt the layer cost backproped untill the input layer
        if self.totalupdatebool:
            self.totalupdate = [None] * (self.depth+1) #update wrt all the layers cost backproped untill the input layer
        #
        self.classify = None
        
        #others methods
        if self.debugmethod:
            self.representation = [None] * (self.depth)
            self.reconstruction = [None] * (self.depth)
            self.validate = [None] * (self.depth)
            self.noisyinputs = [None] * (self.depth)
            self.compute_localcost = [None] * (self.depth+1)
            self.compute_localgradients = [None] * (self.depth+1)
            self.compute_globalcost = [None] * (self.depth+1)
            self.compute_globalgradients = [None] * (self.depth+1)
            if self.totalupdatebool:
                self.compute_totalcost = [None] * (self.depth+1)
                self.compute_totalgradients = [None] * (self.depth+1)
        #
        
        # some theano Variables we want to keep track on
        if self.regularize:
            self.regularizationenccost = [None] * (self.depth)
        self.localcost = [None] * (self.depth+1)
        self.localgradients = [None] * (self.depth+1)
        self.globalcost = [None] * (self.depth+1)
        self.globalgradients = [None] * (self.depth+1)
        if self.totalupdatebool:
            self.totalcost = [None] * (self.depth+1)
            self.totalgradients = [None] * (self.depth+1)
        
        #params to update and inputs initialization
        paramstot = []
        paramsenc = []
        self.inputs = [None] * (self.depth+1)
        
        if self.input is not None:
            self.inputs[0] = [self.input]
        else:
            self.inputs[0] = []
        
        offset = 0
        for i in range(self.depth):
            
            if auxin_size[i] is None:
                offset +=1
                param = [inputprec, None, in_sizeprec, auxin_size[i], self.n_hid[i],\
                    False, self.tie_weights, self.hid_fn, self.reconstruction_cost_function,False]
            else:
                param = [inputprec, self.auxinput[i-offset], in_sizeprec, auxin_size[i], self.n_hid[i],\
                    False, self.tie_weights, self.hid_fn, self.reconstruction_cost_function,False]

            dict_params = dict(ignore_missing = self.ignore_missing,
                    reconstruct_missing = self.reconstruct_missing,
                    corruption_pattern = self.corruption_pattern)
            
            print '\tLayer init= ', i+1
            self.daaig[i] = DAAig(*param, **dict_params)
            
            # method input, outputs and parameters update
            if i:
                self.inputs[i] = copy.copy(self.inputs[i-1])
            if auxin_size[i] is not None:
                self.inputs[i] += [self.daaig[i].idx_list,self.auxinput[i-offset]]
            
            noisyout = []
            if inputprec is not None:
                noisyout += [self.daaig[i].noisy_input]
            if auxin_size[i] is not None:
                noisyout += [self.daaig[i].noisy_auxinput]
            
            paramstot += self.daaig[i].params
            
            # save the costs
            self.localcost[i] = self.daaig[i].noise.cost
            self.globalcost[i] = self.daaig[i].noise.cost
            if self.totalupdatebool:
                if i:
                    self.totalcost[i] = self.totalcost[i-1] + self.daaig[i].noise.cost
                else:
                    self.totalcost[i] = self.daaig[i].noise.cost
            
            if self.regularize:
                if i:
                    self.regularizationenccost[i] = self.regularizationenccost[i-1]+self.daaig[i-1].regularizationenc
                else:
                    self.regularizationenccost[i] = 0
                
                self.localcost[i] += self.daaig[i].regularization
                self.globalcost[i] += self.regularizationenccost[i]
                if self.totalupdatebool:
                    self.totalcost[i] += self.daaig[i].regularization
            
            self.localgradients[i] = T.grad(self.localcost[i], self.daaig[i].params)
            self.globalgradients[i] = T.grad(self.globalcost[i], self.daaig[i].params+paramsenc)
            if self.totalupdatebool:
                self.totalgradients[i] = T.grad(self.totalcost[i], paramstot)
            
            #create the updates dictionnaries
            local_grads = dict((j, j - self.unsup_lr * g) for j,g in zip(self.daaig[i].params,self.localgradients[i]))
            global_grads = dict((j, j - self.unsup_lr * g)\
                    for j,g in zip(self.daaig[i].params+paramsenc,self.globalgradients[i]))
            if self.totalupdatebool:
                total_grads = dict((j, j - self.unsup_lr * g) for j,g in zip(paramstot,self.totalgradients[i]))
            
            # method declaration
            self.localupdate[i] = theano.Method(self.inputs[i],self.localcost[i],local_grads)
            self.globalupdate[i] = theano.Method(self.inputs[i],self.globalcost[i],global_grads)
            if self.totalupdatebool:
                self.totalupdate[i] = theano.Method(self.inputs[i],self.totalcost[i],total_grads)
            #
            if self.debugmethod:
                self.representation[i] = theano.Method(self.inputs[i],self.daaig[i].clean.hidden_activation)
                self.reconstruction[i] = theano.Method(self.inputs[i],self.daaig[i].clean.rec)
                self.validate[i] =theano.Method(self.inputs[i], [self.daaig[i].clean.cost, self.daaig[i].clean.rec])
                self.noisyinputs[i] =theano.Method(self.inputs[i], noisyout)
                self.compute_localcost[i] = theano.Method(self.inputs[i],self.localcost[i])
                self.compute_localgradients[i] = theano.Method(self.inputs[i],self.localgradients[i])
                self.compute_globalcost[i] = theano.Method(self.inputs[i],self.globalcost[i])
                self.compute_globalgradients[i] = theano.Method(self.inputs[i],self.globalgradients[i])
                if self.totalupdatebool:
                    self.compute_totalcost[i] = theano.Method(self.inputs[i],self.totalcost[i])
                    self.compute_totalgradients[i] = theano.Method(self.inputs[i],self.totalgradients[i])
            #
            
            paramsenc += self.daaig[i].paramsenc
            inputprec = self.daaig[i].clean.hidden
            in_sizeprec = self.n_hid[i]
        
        # supervised layer
        print '\tLayer supervised init'
        self.inputs[-1] = copy.copy(self.inputs[-2])+[self.target]
        self.daaig[-1] = LogRegN(in_sizeprec,self.n_out,inputprec,self.target)
        paramstot += self.daaig[-1].params
        
        if self.regularize:
            self.localcost[-1] = self.daaig[-1].regularized_cost
            self.globalcost[-1] = self.daaig[-1].regularized_cost + self.regularizationenccost[-1]
        else:
            self.localcost[-1] = self.daaig[-1].unregularized_cost
            self.globalcost[-1] = self.daaig[-1].unregularized_cost
        
        if self.totalupdatebool:
            self.totalcost[-1] = [self.totalcost[-2], self.globalcost[-1]]
        
        self.localgradients[-1] = T.grad(self.localcost[-1], self.daaig[-1].params)
        self.globalgradients[-1] = T.grad(self.globalcost[-1], self.daaig[-1].params+paramsenc)
        if self.totalupdatebool:
            self.totalgradients[-1] = [T.grad(self.totalcost[-2], paramstot) ,\
                    T.grad(self.globalcost[-1], paramstot) ]
        
        local_grads = dict((j, j - self.sup_lr * g) for j,g in zip(self.daaig[-1].params,self.localgradients[-1]))
        global_grads = dict((j, j - self.sup_lr * g)\
                for j,g in zip(self.daaig[-1].params+paramsenc,self.globalgradients[-1]))
        if self.totalupdatebool:
            total_grads = dict((j, j - self.unsup_lr * g1 - self.sup_lr * g2)\
                    for j,g1,g2 in zip(paramstot,self.totalgradients[-1][0],self.totalgradients[-1][1]))
        
        self.localupdate[-1] = theano.Method(self.inputs[-1],self.localcost[-1],local_grads)
        self.globalupdate[-1] = theano.Method(self.inputs[-1],self.globalcost[-1],global_grads)
        if self.totalupdatebool:
            self.totalupdate[-1] = theano.Method(self.inputs[-1],self.totalcost[-1],total_grads)

        totallocal_grads={}
        for k in range(self.depth):
            totallocal_grads.update(dict((j, j - self.unsup_lr * g) for j,g in
                    zip(self.daaig[k].params,self.localgradients[k])))
        totallocal_grads.update(dict((j, j - self.sup_lr * g) for j,g in zip(self.daaig[-1].params,self.localgradients[-1])))
        self.totallocalupdate = theano.Method(self.inputs[-1],self.localcost,totallocal_grads)
        
        self.classify = theano.Method(self.inputs[-2],self.daaig[-1].argmax_standalone)
        self.NLL = theano.Method(self.inputs[-1],self.daaig[-1]._xent)

        
        if self.debugmethod:
            self.compute_localcost[-1] = theano.Method(self.inputs[-1],self.localcost[-1])
            self.compute_localgradients[-1] = theano.Method(self.inputs[-1],self.localgradients[-1])
            self.compute_globalcost[-1] = theano.Method(self.inputs[-1],self.globalcost[-1])
            self.compute_globalgradients[-1] = theano.Method(self.inputs[-1],self.globalgradients[-1])
            if self.totalupdatebool:
                self.compute_totalcost[-1] = theano.Method(self.inputs[-1],self.totalcost[-1])
                self.compute_totalgradients[-1] =\
                        theano.Method(self.inputs[-1],self.totalgradients[-1][0]+self.totalgradients[-1][1])
    
    def _instance_initialize(self,inst,unsup_lr = 0.1, sup_lr = 0.01, reg_coef = 0,
                                noise_level = 0 , noise_level_group = 0, seed = 1, alloc = True,**init):
        super(StackedDAAig, self)._instance_initialize(inst, **init)
        
        inst.unsup_lr = unsup_lr
        inst.sup_lr = sup_lr
        
        for i in range(self.depth):
            print '\tLayer = ', i+1
            inst.daaig[i].initialize(reg_coef = reg_coef, noise_level = noise_level,\
                    noise_level_group = noise_level_group, seed = seed + i, alloc = alloc)
        print '\tLayer supervised'
        inst.daaig[-1].initialize()
        if alloc:
            inst.daaig[-1].R = numpy.random.RandomState(seed+self.depth)
            # init the logreg weights
            inst.daaig[-1].w = inst.daaig[-1].R.uniform(size=inst.daaig[-1].w.shape,\
                                low = -1/numpy.sqrt(inst.daaig[-2].n_hid), high = 1/numpy.sqrt(inst.daaig[-2].n_hid))
        inst.daaig[-1].l1 = 0
        inst.daaig[-1].l2 = reg_coef #only l2 norm for regularisation to be consitent with the unsup regularisation