changeset 178:4090779e39a9

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 13 May 2008 15:12:20 -0400
parents 69759976b3ac
children 9911d2cc3c01
files mlp.py stopper.py
diffstat 2 files changed, 142 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/mlp.py	Tue May 13 15:11:47 2008 -0400
+++ b/mlp.py	Tue May 13 15:12:20 2008 -0400
@@ -10,6 +10,21 @@
 from nnet_ops import *
 import math
 
+def sum_l2_cost(*params):
+    p = params[0]
+    rval = t.sum(p*p)
+    for p in params[1:]:
+        rval = rval + t.sum(p*p)
+    return rval
+
+def activation(w, b, v, c, x):
+    return t.dot(t.tanh(t.dot(x, w) + b), v) + c
+def nll(w, b, v, c, x, y):
+    return  crossentropy_softmax_1hot(activation(w, b, v, c, x), y)[0]
+def output(w, b, v, c, x, y):
+    return  crossentropy_softmax_1hot(activation(w, b, v, c, x), y)[1]
+
+
 
 class OneHiddenLayerNNetClassifier(OnlineGradientTLearner):
     """
@@ -67,7 +82,6 @@
        - 'regularization_term'
 
     """
-
     def __init__(self,n_hidden,n_classes,learning_rate,max_n_epochs,L2_regularizer=0,init_range=1.,n_inputs=None,minibatch_size=None):
         self._n_inputs = n_inputs
         self._n_outputs = n_classes
@@ -142,6 +156,25 @@
         self._n_epochs +=1
         return self._n_epochs>=self._max_n_epochs
 
+    def updateMinibatch(self,minibatch):
+        # make sure all required fields are allocated and initialized
+        self.allocate(minibatch)
+        input_attributes = self.names2attributes(self.updateMinibatchInputAttributes())
+        input_fields = minibatch(*self.updateMinibatchInputFields())
+        print 'input attributes', input_attributes
+        print 'input fields', input_fields
+        results = self.update_minibatch_function(*(input_attributes+input_fields))
+        print 'output attributes', self.updateMinibatchOutputAttributes()
+        print 'results', results
+        self.setAttributes(self.updateMinibatchOutputAttributes(),
+                           results)
+
+        if 0:
+            print 'n0', self.names2OpResults(self.updateMinibatchOutputAttributes()+ self.updateMinibatchInputFields())
+            print 'n1', self.names2OpResults(self.updateMinibatchOutputAttributes())
+            print 'n2', self.names2OpResults(self.updateEndInputAttributes())
+            print 'n3', self.names2OpResults(self.updateEndOutputAttributes())
+
 class MLP(MinibatchUpdatesTLearner):
     """
     Implement a feedforward multi-layer perceptron, with or without L1 and/or L2 regularization.
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stopper.py	Tue May 13 15:12:20 2008 -0400
@@ -0,0 +1,108 @@
+"""Early stopping iterators
+
+The idea here is to supply early-stopping heuristics that can be used in the
+form:
+
+    stopper = SomeEarlyStopper()
+
+    for i in stopper:
+        # train from data
+        if i.set_score:
+            i.score = validation_score
+
+
+So far I only have one heuristic, so maybe this won't scale.
+"""
+
+class Stopper(object):
+
+    def train(self, data, update_rows_fn, update, validate, save=None):
+        """Return the best model trained on data
+
+        Parameters:
+        data - a thing that accepts getitem(<list of int64>), or a tuple of such things
+        update_rows_fn - fn : int --> <list or tensor of int>
+        update - fn: update an internal model from elements of data
+        validate - fn: evaluate an internal model based on elements of data
+        save - fn: return a copy of the internal model (optional; if omitted, None is returned)
+
+        The body of this function exhausts the <self> iterator, and trains a
+        model using early stopping in the process.
+        """
+
+        best = None
+        for stp in self:
+            i = stp.iter
+
+            # call update on some training set rows
+            t_rows = update_rows_fn(i)
+            if isinstance(data, (tuple, list)):
+                update(*[d[t_rows] for d in data])
+            else:
+                update(data[t_rows])
+
+            if stp.set_score:
+                stp.score = validate()
+                if (stp.score < stp.best_score) and save:
+                    best = save()
+        return best
+
+
+class ICML08Stopper(Stopper):
+    @staticmethod
+    def icml08(ntrain, batchsize):
+        """Some setting similar to what I used for ICML08 submission"""
+        #TODO: what did I actually use? put that in here.
+        return ICML08Stopper(30*ntrain/batchsize,
+                ntrain/batchsize, 0.96, 2.0, 100000000)
+
+    def __init__(self, i_wait, v_int, min_improvement, patience, hard_limit):
+        self.initial_wait = i_wait
+        self.set_score_interval = v_int
+        self.min_improvement = min_improvement
+        self.patience = patience
+        self.hard_limit = hard_limit
+
+        self.best_score = float('inf')
+        self.best_iter = -1
+        self.iter = -1
+
+        self.set_score = False
+        self.score = None
+
+    def __iter__(self):
+        return self
+
+    E_set_score = 'when iter.set_score is True, caller must assign a score to iter.score'
+    def next(self):
+        if self.set_score: #left over from last time
+            if self.score is None:
+                raise Exception(ICML08Stopper.E_set_score)
+            if self.score < (self.best_score * self.min_improvement):
+                (self.best_score, self.best_iter) = (self.score, self.iter)
+            self.score = None #un-set it
+
+
+        starting = self.iter < self.initial_wait
+        waiting = self.iter < (self.patience * self.best_iter)
+        if starting or waiting:
+            # continue to iterate
+            self.iter += 1
+            if self.iter == self.hard_limit:
+                raise StopIteration
+            self.set_score = (self.iter % self.set_score_interval == 0)
+            return self
+
+        raise StopIteration
+
+
+class NStages(ICML08Stopper):
+    """Run for a fixed number of steps, checking validation set every so
+    often."""
+    def __init__(self, hard_limit, v_int):
+        ICML08Stopper.__init__(self, hard_limit, v_int, 1.0, 1.0, hard_limit)
+
+    #TODO: could optimize next() function. Most of what's in ICML08Stopper.next()
+    #is not necessary
+
+