changeset 754:390d8c5a1fee

merge
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 02 Jun 2009 20:21:35 -0400
parents 0eee6693f149 (current diff) 84d22b7d835a (diff)
children 8447bc9bb2d4 a7dc8b28f4bc
files pylearn/algorithms/sgd.py
diffstat 5 files changed, 139 insertions(+), 49 deletions(-)
--- a/pylearn/algorithms/sandbox/DAA_inputs_groups.py	Tue Jun 02 20:21:21 2009 -0400
+++ b/pylearn/algorithms/sandbox/DAA_inputs_groups.py	Tue Jun 02 20:21:35 2009 -0400
@@ -1,18 +1,20 @@
 import numpy
 import theano
-import os, copy
+import copy
 
 from theano import tensor as T
 from theano.compile import module
-from theano.tensor.nnet import sigmoid
 
 from pylearn.sandbox.scan_inputs_groups import scaninputs, scandotdec, scandotenc, scannoise, scanbiasdec, \
-        scanmaskenc,scanmaskdec, fill_missing_with_zeros, mask_gradient
+        scanmaskenc,scanmaskdec, FillMissing, mask_gradient
 
-from pylearn.algorithms import cost
 from pylearn.algorithms.logistic_regression import LogRegN
 
+# used to initialize containers
+class ScratchPad:
+    pass
 
+# regularisation utils:-------------------------------------------
 def lnorm(param, type='l2'):
     if type == 'l1':
         return T.sum(T.abs_(param))
@@ -26,19 +28,40 @@
         rcost += lnorm(param, type)
     return rcost
 
-
+# activations utils:----------------------------------------------
 def sigmoid_act(x):
     return theano.tensor.nnet.sigmoid(x)
 
 def tanh_act(x):
-    return (theano.tensor.tanh((x-0.5)*2))/2.0+0.5
+    return theano.tensor.tanh(x)
+
+# costs utils:---------------------------------------------------
+
+# To avoid numerical instability in the cross-entropy cost and its gradient, the
+# following functions compute it directly from the (pre-nonlinearity) activation:
+
+def sigmoid_cross_entropy(target, output_act, mean_axis, sum_axis):
+    XE =-target * T.log(1 + T.exp(-output_act)) + (1 - target) * (- T.log(1 + T.exp(output_act)))
+    return -T.mean(T.sum(XE, axis=sum_axis),axis=mean_axis)
 
-def softsign_act(x):
-    return theano.sandbox.softsign.softsign(x)
+def tanh_cross_entropy(target, output_act, mean_axis, sum_axis):
+    XE =-(target+1)/2.0 * T.log(1 + T.exp(-2 * output_act)) + \
+            (1 - (target+1)/2.0) * (- T.log(1 + T.exp(2 * output_act)))
+    return -T.mean(T.sum(XE, axis=sum_axis),axis=mean_axis)
 
-class ScratchPad:
-    pass
+def cross_entropy(target, output_act, act, mean_axis=0, sum_axis=1):
+    if act == 'sigmoid_act':
+        return sigmoid_cross_entropy(target, output_act, mean_axis, sum_axis)
+    if act == 'tanh_act':
+        return tanh_cross_entropy(target, output_act, mean_axis, sum_axis)
+    assert False
 
+def quadratic(target, output, act, axis = 1):
+    from pylearn.algorithms import cost
+    return cost.quadratic(target, output, axis)
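The pre-activation cross-entropy above is a numerical-stability measure: computing log(sigmoid(a)) by applying the sigmoid first underflows to log(0) for strongly negative activations, while the algebraically equivalent softplus form -log(1 + exp(-a)) can be evaluated stably (Theano typically rewrites log(1 + exp(x)) into a stable softplus). A minimal NumPy sketch of the difference, illustrative and not part of this module:

import numpy

def naive_log_sigmoid(a):
    # log(sigmoid(a)) computed the obvious way; underflows to -inf for very negative a
    return numpy.log(1.0 / (1.0 + numpy.exp(-a)))

def stable_log_sigmoid(a):
    # algebraically identical: log(sigmoid(a)) = -log(1 + exp(-a)) = -softplus(-a)
    return -numpy.logaddexp(0.0, -a)

a = numpy.array([-800.0, -5.0, 5.0])
print(naive_log_sigmoid(a))   # [-inf -5.00671535 -0.00671535] plus an overflow warning
print(stable_log_sigmoid(a))  # [-800. -5.00671535 -0.00671535]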
+
+
+
+# DAAig module----------------------------------------------------------------
 class DAAig(module.Module):
     """De-noising Auto-encoder
     """
@@ -46,24 +69,36 @@
     def __init__(self, input = None, auxinput = None,
                 in_size=None, auxin_size= None, n_hid=1,
                 regularize = False, tie_weights = False, hid_fn = 'sigmoid_act',
-                reconstruction_cost_function=cost.cross_entropy, interface = True,
-                ignore_missing=False,
+                reconstruction_cost_function='cross_entropy', interface = True,
+                ignore_missing=None, reconstruct_missing=False,
+                corruption_pattern=None,
                 **init):
         """
         :param regularize: WRITEME
         :param tie_weights: WRITEME
         :param hid_fn: WRITEME
         :param reconstruction_cost: Should return one cost per example (row)
-        :param ignore_missing: if True, the input will be scanned in order to
-            detect missing values, and these values will be replaced by zeros.
-            Also, the reconstruction cost's gradient will be computed only on
-            non missing components.
-            If False, the presence of missing values may cause crashes or other
+        :param ignore_missing: if not None, the input will be scanned in order
+            to detect missing values, and these values will be replaced. Also,
+            the reconstruction cost's gradient will be computed only on non
+            missing components. The value of this parameter indicates how to
+            replace missing values:
+                - some numpy.ndarray: value of this array at the same index
+                - a constant: this same value everywhere
+            If None, the presence of missing values may cause crashes or other
             weird and unexpected behavior.
             Please note that this option only affects the permanent input, not
             auxiliary ones (that should never contain missing values). In fact,
             in the current implementation, auxiliary inputs cannot be used when
             this option is True.
+        :param corruption_pattern: if not None, specifies a particular way to
+            corrupt the input with missing values. Valid choices are:
+                - 'by_pair': features are given as pairs, and the whole pair is
+                  corrupted (or not) rather than each element independently.
+                  Elements of a pair are not consecutive; they are assumed to
+                  be at distance (total number of features / 2) of each other.
+        :param reconstruct_missing: if True, the reconstruction cost on missing
+            inputs will be backpropagated; otherwise, it will not.
         :todo: Default noise level for all daa levels
         """
         print '\t\t**** DAAig.__init__ ****'
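For illustration, a minimal sketch of how the new constructor arguments might be used. The sizes and fill values below are illustrative assumptions, and building the full Theano module may require additional setup (noise levels, make(), ...) not shown here:

import numpy
from theano import tensor as T
from pylearn.algorithms.sandbox.DAA_inputs_groups import DAAig

x = T.dmatrix('x')
column_means = numpy.zeros(20)             # illustrative per-feature fill values

daa = DAAig(input=x, in_size=20, n_hid=10,
            hid_fn='tanh_act',
            reconstruction_cost_function='cross_entropy',
            ignore_missing=column_means,   # NaNs in x are replaced by column_means
            reconstruct_missing=False,     # no gradient through the filled-in entries
            corruption_pattern='by_pair')  # corrupt features two by two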
@@ -82,18 +117,25 @@
         self.n_hid = n_hid
         self.regularize = regularize
         self.tie_weights = tie_weights
-        self.reconstruction_cost_function = reconstruction_cost_function
         self.interface = interface
         self.ignore_missing = ignore_missing
+        self.reconstruct_missing = reconstruct_missing
+        self.corruption_pattern = corruption_pattern
         
-        assert hid_fn in ('sigmoid_act','tanh_act','softsign_act')
+        
+        assert hid_fn in ('sigmoid_act','tanh_act')
         self.hid_fn = eval(hid_fn)
+        self.hid_name = hid_fn
+        
+        assert reconstruction_cost_function in ('cross_entropy','quadratic')
+        self.reconstruction_cost_function = eval(reconstruction_cost_function)
+        self.reconstruction_cost_function_name = reconstruction_cost_function
         
         ### DECLARE MODEL VARIABLES and default
         self.input = input
-        if self.ignore_missing and self.input is not None:
-            no_missing = fill_missing_with_zeros(self.input)
-            self.input = no_missing[0]  # Missing values replaced by zeros.
+        if self.ignore_missing is not None and self.input is not None:
+            no_missing = FillMissing(self.ignore_missing)(self.input)
+            self.input = no_missing[0]  # With missing values replaced.
             self.input_missing_mask = no_missing[1] # Missingness pattern.
         else:
             self.input_missing_mask = None
@@ -131,7 +173,8 @@
             self.noisy_input = self.corrupt_input()
         if self.auxinput is not None:
             self.noisy_idx_list , self.noisy_auxinput = \
-                scannoise(self.idx_list,self.auxinput,self.noise_level,self.noise_level_group)
+                scannoise(self.idx_list, self.auxinput,self.noise_level,
+                        self.noise_level_group)
         
         self.noise = ScratchPad()
         self.clean = ScratchPad()
@@ -152,7 +195,8 @@
         container.hidden = self.hid_fn(container.hidden_activation)
         self.define_propdown(container, idx_list , auxinput)
         container.rec = self.hid_fn(container.rec_activation)
-        if self.ignore_missing and self.input is not None:
+        if (self.ignore_missing is not None and self.input is not None and not
+                self.reconstruct_missing):
             # Apply mask to gradient to ensure we do not backpropagate on the
             # cost computed on missing inputs (that were filled in).
             container.rec = mask_gradient(container.rec,
@@ -212,11 +256,15 @@
     
     # DEPENDENCY: define_behavioural, define_regularization
     def define_cost(self, container):
-        container.reconstruction_cost = self.reconstruction_costs(container.rec)
+        if self.reconstruction_cost_function_name == 'cross_entropy':
+            container.reconstruction_cost = self.reconstruction_costs(container.rec_activation)
+        else:
+            container.reconstruction_cost = self.reconstruction_costs(container.rec)
         # TOTAL COST
-        container.cost = container.reconstruction_cost
         if self.regularize: #if stacked don't merge regularization and cost here but in the stackeddaaig module
-            container.cost = container.cost + self.regularization
+            container.cost = container.reconstruction_cost + self.regularization
+        else:
+            container.cost = container.reconstruction_cost
     
     # DEPENDENCY: define_cost
     def define_params(self):
@@ -263,15 +311,28 @@
         self.validate = theano.Method(listin, [self.clean.cost, self.clean.rec])
     
     def corrupt_input(self):
-        return self.random.binomial(T.shape(self.input), 1, 1 - self.noise_level) * self.input
+        if self.corruption_pattern is None:
+            mask = self.random.binomial(T.shape(self.input), 1, 1 - self.noise_level)
+        elif self.corruption_pattern == 'by_pair':
+            shape = T.shape(self.input)
+            scale = numpy.ones(2)
+            scale[1] = 2
+            shape = shape / scale
+            mask = self.random.binomial(shape, 1, 1 - self.noise_level)
+            mask = T.hstack((mask, mask))
+        else:
+            raise ValueError('Unknown value for corruption_pattern: %s'
+                    % self.corruption_pattern)
+        return mask * self.input
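The 'by_pair' branch draws one keep/drop decision per feature pair and repeats it across the two halves of the feature vector, so a feature and its partner at distance n_features/2 are corrupted together. A small NumPy sketch of the same masking logic, with illustrative names and sizes:

import numpy

rng = numpy.random.RandomState(0)
n_examples, n_features, noise_level = 4, 6, 0.5

# one binomial draw per pair, duplicated for the partner feature located
# n_features/2 positions away (mirrors the T.hstack((mask, mask)) above)
half = rng.binomial(1, 1 - noise_level, size=(n_examples, n_features // 2))
mask = numpy.hstack((half, half))

x = numpy.arange(n_examples * n_features, dtype='float64').reshape(n_examples, n_features)
corrupted = mask * x   # columns j and j + n_features/2 share the same fate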
     
     def reconstruction_costs(self, rec):
         if (self.input is not None) and (self.auxinput is not None):
-            return self.reconstruction_cost_function(T.join(1,self.input,scaninputs(self.idx_list,self.auxinput)), rec)
+            return self.reconstruction_cost_function(T.join(1,self.input,scaninputs(self.idx_list,self.auxinput)),\
+                    rec, self.hid_name)
         if self.input is not None:
-            return self.reconstruction_cost_function(self.input, rec)
+            return self.reconstruction_cost_function(self.input, rec, self.hid_name)
         if self.auxinput is not None:
-            return self.reconstruction_cost_function(scaninputs(self.idx_list,self.auxinput), rec)
+            return self.reconstruction_cost_function(scaninputs(self.idx_list,self.auxinput), rec, self.hid_name)
         # All cases should be covered above. If not, something is wrong!
         assert False
     
@@ -338,9 +399,10 @@
     def __init__(self, depth = 1, input = T.dmatrix('input'), auxinput = [None],
                 in_size = None, auxin_size = [None], n_hid = [1],
                 regularize = False, tie_weights = False, hid_fn = 'sigmoid_act',
-                reconstruction_cost_function=cost.cross_entropy,
+                reconstruction_cost_function='cross_entropy',
                 n_out = 2, target = None, debugmethod = False, totalupdatebool=False,
-                ignore_missing=False,
+                ignore_missing=None, reconstruct_missing=False,
+                corruption_pattern=None,
                 **init):
         
         super(StackedDAAig, self).__init__()
@@ -367,6 +429,8 @@
         self.debugmethod = debugmethod
         self.totalupdatebool = totalupdatebool
         self.ignore_missing = ignore_missing
+        self.reconstruct_missing = reconstruct_missing
+        self.corruption_pattern = corruption_pattern
         
         # init for model construction
         inputprec = input
@@ -432,7 +496,9 @@
                 param = [inputprec, self.auxinput[i-offset], in_sizeprec, auxin_size[i], self.n_hid[i],\
                     False, self.tie_weights, self.hid_fn, self.reconstruction_cost_function,False]
 
-            dict_params = dict(ignore_missing = self.ignore_missing)
+            dict_params = dict(ignore_missing = self.ignore_missing,
+                    reconstruct_missing = self.reconstruct_missing,
+                    corruption_pattern = self.corruption_pattern)
             
             print '\tLayer init= ', i+1
             self.daaig[i] = DAAig(*param, **dict_params)
--- a/pylearn/algorithms/sgd.py	Tue Jun 02 20:21:21 2009 -0400
+++ b/pylearn/algorithms/sgd.py	Tue Jun 02 20:21:35 2009 -0400
@@ -69,7 +69,7 @@
     
     :returns: standard minimizer constructor f(args, cost, params, gradients=None)
     """
-    def f(args, cost, params, gradient=None, updates=None, auxout=None):
-        return StochasticGradientDescent(args, cost, params, gradient, stepsize,
+    def f(args, cost, params, gradients=None, updates=None, auxout=None):
+        return StochasticGradientDescent(args, cost, params, gradients=gradients, stepsize=stepsize,
                 updates=updates, auxout=auxout)
     return f
--- a/pylearn/algorithms/stopper.py	Tue Jun 02 20:21:21 2009 -0400
+++ b/pylearn/algorithms/stopper.py	Tue Jun 02 20:21:35 2009 -0400
@@ -1,3 +1,4 @@
+import time
 """Early stopping iterators
 
 The idea here is to supply early-stopping heuristics that can be used in the
@@ -65,12 +66,14 @@
         return ICML08Stopper(30*ntrain/batchsize,
                 ntrain/batchsize, 0.96, 2.0, 100000000)
 
-    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit):
+    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit, hard_time_limit=None):
         self.initial_wait = i_wait
         self.set_score_interval = v_int
         self.min_improvement = min_improvement
         self.patience = patience
         self.hard_limit = hard_limit
+        self.hard_limit_seconds = hard_time_limit
+        self.start_time = time.time()
 
         self.best_score = float('inf')
         self.best_iter = -1
@@ -97,7 +100,8 @@
 
         starting = self.iter < self.initial_wait
         waiting = self.iter < (self.patience * self.best_iter)
-        if starting or waiting:
+        times_up = (time.time() - self.start_time) > self.hard_limit_seconds if self.hard_limit_seconds is not None else False
+        if (starting or waiting) and not times_up:
             # continue to iterate
             self.iter += 1
             if self.iter == self.hard_limit:
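A minimal sketch of the new wall-clock limit; the argument values are illustrative:

from pylearn.algorithms.stopper import ICML08Stopper

# stop on the usual early-stopping criteria, after hard_limit iterations,
# or after roughly one hour of wall-clock time (the new hard_time_limit)
stopper = ICML08Stopper(i_wait=30, v_int=10, min_improvement=0.96,
                        patience=2.0, hard_limit=1000,
                        hard_time_limit=3600)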
--- a/pylearn/datasets/norb_small.py	Tue Jun 02 20:21:21 2009 -0400
+++ b/pylearn/datasets/norb_small.py	Tue Jun 02 20:21:35 2009 -0400
@@ -68,8 +68,8 @@
     path = Paths()
 
     def __init__(self, ntrain=19440, nvalid=4860, ntest=24300, 
-               downsample_amt=1, seed=1, normalize=True,
-               mode='stereo', dtype='float64'):
+               downsample_amt=1, seed=1, normalize=False,
+               mode='stereo', dtype='int8'):
 
         self.n_classes = 5
         self.nsamples = 24300
--- a/pylearn/sandbox/scan_inputs_groups.py	Tue Jun 02 20:21:21 2009 -0400
+++ b/pylearn/sandbox/scan_inputs_groups.py	Tue Jun 02 20:21:35 2009 -0400
@@ -567,23 +567,37 @@
     """
     Given an input, output two elements:
         - a copy of the input where missing values (NaN) are replaced by some
-        constant (zero by default)
+        other value (zero by default)
         - a mask of the same size and type as input, where each element is zero
         iff the corresponding input is missing
-    Currently, the gradient is computed as if the input value was really zero.
-    It may be safer to replace the gradient w.r.t. missing values with either
-    zeros or missing values (?).
+    The 'fill_with' parameter may either be:
+        - a scalar: all missing values are replaced with this value
+        - a Numpy array: a missing value is replaced by the value in this array
+        at the same position (ignoring the first k dimensions if 'fill_with'
+        has k fewer dimensions than the input)
+    Currently, the gradient is computed as if the input value was really what
+    it was replaced with. It may be safer to replace the gradient w.r.t.
+    missing values with either zeros or missing values (?).
     """
 
-    def __init__(self, constant_val=0):
+    def __init__(self, fill_with=0):
         super(Op, self).__init__()
-        self.constant_val = constant_val
+        self.fill_with = fill_with
+        self.fill_with_is_array = isinstance(self.fill_with, numpy.ndarray)
 
     def __eq__(self, other):
-        return type(self) == type(other) and (self.constant_val == other.constant_val)
+        return (type(self) == type(other) and
+                self.fill_with_is_array == other.fill_with_is_array and
+                ((self.fill_with == other.fill_with).all()
+                    if self.fill_with_is_array
+                    else self.fill_with == other.fill_with))
 
-	def __hash__(self):
-		return hash(type(self))^hash(self.constant_val)
+    def __hash__(self):
+        if self.fill_with_is_array:
+            # ndarrays are not reliably hashable; hash their raw contents instead
+            fill_hash = hash(self.fill_with.tostring())
+        else:
+            fill_hash = hash(self.fill_with)
+        return hash(type(self))^hash(self.fill_with_is_array)^fill_hash
 	
     def make_node(self, input):
         return Apply(self, [input], [input.type(), input.type()])
@@ -595,9 +609,15 @@
         mask = output_storage[1]
         mask[0] = numpy.ones(input.shape)
         mask = mask[0]
+        if self.fill_with_is_array:
+            ignore_k = len(out.shape) - len(self.fill_with.shape)
+            assert ignore_k >= 0
         for (idx, v) in numpy.ndenumerate(out):
             if numpy.isnan(v):
-                out[idx] = self.constant_val
+                if self.fill_with_is_array:
+                    out[idx] = self.fill_with[idx[ignore_k:]]
+                else:
+                    out[idx] = self.fill_with
                 mask[idx] = 0
 
     def grad(self, inputs, (out_grad, mask_grad, )):