Mercurial > pylearn
view pylearn/shared/layers/logreg.py @ 1501:55534951dd91
Clean up import and remove deprecation warning.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Fri, 09 Sep 2011 10:53:46 -0400 |
parents | 519e82748a55 |
children |
line wrap: on
line source
"""Provides LogisticRegression, a softmax (multinomial logistic regression) layer."""

import numpy

import theano
from theano.compile import shared
from theano.tensor import nnet

from pylearn.shared.layers.util import update_locals, add_logging


class LogisticRegression(object):
    """Softmax classification layer computing ``softmax(dot(input, w) + b)``.

    The symbolic expressions built in ``__init__`` (``output``, ``l1``,
    ``l2_sqr``, ``argmax``) and the constructor arguments are attached to the
    instance through ``update_locals`` (``nll`` and ``errors`` read
    ``self.output`` / ``self.argmax``).
    """

    def __init__(self, input, w, b, params=None):
        """Build the symbolic graph of the layer.

        :param input: symbolic input variable; presumably a (batch, n_in)
            matrix -- TODO confirm with callers.
        :param w: weight variable, right-multiplied onto ``input``.
        :param b: bias variable added to ``dot(input, w)``.
        :param params: list of trainable parameters. Defaults to a fresh empty
            list per instance (a shared ``[]`` default would leak between
            instances once stored on ``self``).
        """
        # NOTE: every local name defined here is copied onto the instance by
        # update_locals(self, locals()), so do not introduce extra locals
        # unless they are meant to become attributes.
        if params is None:
            params = []
        output = nnet.softmax(theano.dot(input, w) + b)
        # L1 norm and squared L2 norm of the weights, usable as regularizers.
        l1 = abs(w).sum()
        l2_sqr = (w ** 2).sum()
        # Predicted class = index of the largest pre-softmax score along the
        # last axis; softmax preserves ordering, so this matches argmax(output).
        argmax = theano.tensor.argmax(theano.dot(input, w) + b,
                                      axis=input.ndim - 1)
        update_locals(self, locals())

    @classmethod
    def new(cls, input, n_in, n_out, dtype=None, name=None):
        """Allocate zero-initialized shared ``w`` and ``b`` and build a layer.

        :param input: symbolic input variable fed to the layer.
        :param n_in: number of input features (rows of ``w``).
        :param n_out: number of output classes (columns of ``w``, length of ``b``).
        :param dtype: dtype of the parameters; defaults to ``input.dtype``.
        :param name: prefix for the shared variables' names; defaults to the
            class name, giving ``'<name>.w'`` and ``'<name>.b'``.
        :returns: a new instance whose ``params`` is ``[w, b]``.
        """
        if dtype is None:
            dtype = input.dtype
        if name is None:
            name = cls.__name__
        # _debug is presumably installed on the class by add_logging below.
        cls._debug('allocating params w, b', n_in, n_out, dtype)
        w = shared(numpy.zeros((n_in, n_out), dtype=dtype), name='%s.w' % name)
        b = shared(numpy.zeros((n_out,), dtype=dtype), name='%s.b' % name)
        return cls(input, w, b, params=[w, b])

    def nll(self, target):
        """Return the negative log-likelihood of the prediction under ``target``.

        Delegates to ``nnet.categorical_crossentropy(self.output, target)``.
        ``target`` is either a target distribution over classes, or symbolic
        integers, which are interpreted as class indices (i.e. a 1-hot
        distribution).

        :returns: symbolic per-example negative log-likelihood.
        """
        return nnet.categorical_crossentropy(self.output, target)

    def errors(self, target):
        """Return a vector of 0s and 1s, with 1s on every line that was mis-classified.

        :param target: symbolic integer (signed or unsigned) class indices with
            the same number of dimensions as ``self.argmax``.
        :raises TypeError: if ``target.ndim`` differs from ``self.argmax.ndim``.
        :raises NotImplementedError: if ``target`` is not of an integer dtype.
        """
        if target.ndim != self.argmax.ndim:
            raise TypeError('target should have the same shape as self.argmax',
                            ('target', target.type, 'argmax', self.argmax.type))
        # Accept unsigned class-index dtypes (e.g. uint8) as well as signed ones.
        if target.dtype.startswith(('int', 'uint')):
            return theano.tensor.neq(self.argmax, target)
        raise NotImplementedError(
            'errors() only supports integer (class index) targets',
            ('target', target.type))


add_logging(LogisticRegression)