# HG changeset patch
# User James Bergstra
# Date 1238430361 14400
# Node ID 719194960d180a03e5a8349452c0460077f46a2a
# Parent d69e668ab904dd2aa7596e63a1753d49a98accc0
# Parent 070a7d68d3a1aba28a8a7463ae3ef0df94facfd5
merge

diff -r d69e668ab904 -r 719194960d18 pylearn/algorithms/daa.py
--- a/pylearn/algorithms/daa.py	Mon Mar 30 12:25:42 2009 -0400
+++ b/pylearn/algorithms/daa.py	Mon Mar 30 12:26:01 2009 -0400
@@ -2,11 +2,13 @@
 import theano
 from theano import tensor as T
 from theano.tensor import nnet as NN
+from theano.tensor.deprecated import rmodule
+
 import numpy as N
 
-from pylearn import cost as cost
+from pylearn.algorithms import cost
 
-class DenoisingAA(T.RModule):
+class DenoisingAA(rmodule.RModule):
     """De-noising Auto-encoder
 
     WRITEME
diff -r d69e668ab904 -r 719194960d18 pylearn/algorithms/rbm.py
--- a/pylearn/algorithms/rbm.py	Mon Mar 30 12:25:42 2009 -0400
+++ b/pylearn/algorithms/rbm.py	Mon Mar 30 12:26:01 2009 -0400
@@ -1,6 +1,7 @@
 import sys, copy
 import theano
 from theano import tensor as T
+from theano.tensor.deprecated import rmodule
 from theano.tensor.nnet import sigmoid
 from theano.compile import module
 from theano import printing, pprint
@@ -11,9 +12,8 @@
 from ..datasets import make_dataset
 from .minimizer import make_minimizer
 from .stopper import make_stopper
-from ..dbdict.experiment import subdict
 
-class RBM(T.RModule):
+class RBM(rmodule.RModule):
 
     # is it really necessary to pass ALL of these ? - GD
     def __init__(self,
@@ -76,7 +76,7 @@
 
 
 def train_rbm(state, channel=lambda *args, **kwargs:None):
-    dataset = make_dataset(**subdict_copy(state, prefix='dataset_'))
+    dataset = make_dataset(**state.dataset)
     train = dataset.train
 
     rbm_module = RBM(
diff -r d69e668ab904 -r 719194960d18 pylearn/algorithms/sandbox/__init__.py
diff -r d69e668ab904 -r 719194960d18 pylearn/algorithms/sandbox/cost.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/algorithms/sandbox/cost.py	Mon Mar 30 12:26:01 2009 -0400
@@ -0,0 +1,154 @@
+"""
+Cost functions.
+
+@note: All of these functions return one cost per example. So it is your
+job to perform a tensor.sum over the individual example losses.
+"""
+
+import theano as T
+from theano import tensor, scalar
+import numpy
+
+class UndefinedGradient(Exception):
+    """
+    Raised by UndefinedGradientOp to indicate that the gradient is undefined mathematically.
+    """
+    pass
+from theano import gof
+class UndefinedGradientOp(gof.Op):
+    def perform(self, x=None):
+        if x is not None: raise UndefinedGradient(x)
+        else: raise UndefinedGradient(x)
+undefined_gradient = UndefinedGradientOp()
+
+class LogFactorial(scalar.UnaryScalarOp):
+    """
+    Compute log x!.
+    @todo: Rewrite so that it uses INTs not FLOATs.
+    @todo: Move this to Theano.
+    @todo: This function is slow, probably want to cache the values.
+    """
+    @staticmethod
+    def st_impl(x):
+        if not isinstance(x, int) and not isinstance(x, long):
+            raise TypeError('type(x) = %s, must be int or long' % type(x))
+        if x == 0.0:
+            return 0.0
+        v = 0.0
+        for i in range(x):
+            v += numpy.log(x)
+        return v
+    def impl(self, x):
+        return LogFactorial.st_impl(x)
+    def grad(self, (x,), (gz,)):
+        undefined_gradient(self)
+#    def grad(self, (x,), (gz,)):
+#        raise NotImplementedError('gradient not defined over discrete values')
+#        return None
+#        return [gz * (1 + scalar.log(x))]
+#    def c_code(self, node, name, (x,), (z,), sub):
+#        if node.inputs[0].type in [scalar.float32, scalar.float64]:
+#            return """%(z)s =
+#                %(x)s == 0.0
+#                ? 0.0
+#                : %(x)s * log(%(x)s);""" % locals()
+#        raise NotImplementedError('only floatingpoint is implemented')
scalar_logfactorial = LogFactorial(scalar.upgrade_to_float, name='scalar_logfactoral')
+logfactorial = tensor.Elemwise(scalar_logfactorial, name='logfactorial')
+
+
+def poissonlambda(unscaled_output, doclen, beta_scale):
+    """
+    A continuous parameter lambda_i which is the expected number of
+    occurence of word i in the document. Note how this must be positive,
+    and that is why Ranzato and Szummer (2008) use an exponential.
+
+    Yoshua: I don't like exponentials to guarantee positivity. softplus
+    is numerically much better behaved (but you might want to try both
+    to see if it makes a difference).
+
+    @todo: Maybe there are more sensible ways to set the beta_scale.
+    """
+    beta = beta_scale * doclen
+    return beta * tensor.exp(unscaled_output)
+
+def nlpoisson(target, output, beta_scale=1, axis=0, sumloss=True, zerothreshold=0):
+    """
+    The negative log Poisson regression probability.
+    From Ranzato and Szummer (2008).
+
+    Output should be of the form Weight*code+bias, i.e. unsquashed.
+    NB this is different than the formulation in Salakhutdinov and Hinton
+    (2007), in which the output is softmax'ed and multiplied by the input
+    document length. That is also what Welling et. al (2005) do. It would
+    be useful to try the softmax, because it is more well-behaved.
+
+    There is a beta term that is proportional to document length. We
+    are not sure what beta scale is used by the authors. We use 1 as
+    the default, but this value might be inappropriate.
+    For numerical reasons, Yoshua recommends choosing beta such that
+    the lambda is expected to be around 1 for words that have a non-zero count.
+    So he would take:
+
+        beta = document_size / unique_words_per_document
+
+    I am not sure the above math is correct, I need to talk to him.
+
+    Yoshua notes that ``there is a x_i log(beta) term missing, if you
+    compare with eqn 2 (i.e., take the log). They did not include in
+    3 because it does not depend on the parameters, so the gradient
+    wrt it would be 0. But if you really want log-likelihood it should
+    be included.'' If you want a true log-likelihood, you probably should
+    actually compute the derivative of the entire eqn 2.
+
+    Axis is the axis along which we sum the target values, to obtain
+    the document length.
+
+    If sumloss, we sum the loss along axis.
+
+    If zerothreshold is non-zero, we threshold the loss:
+        If this target dimension is zero and beta * tensor.exp(output)
+        < zerothreshold, let this loss be zero.
+
+    @todo: Include logfactorial term
+    """
+#    from theano.printing import Print
+#    print dtype(target)        # make sure dtype is int32 or int64
+#    print target.dtype
+    doclen = tensor.sum(target, axis=axis)
+    lambdav = poissonlambda(output, doclen, beta_scale)
+    lossterms = lambdav - target*output
+    if sumloss:
+        return tensor.sum(lossterms, axis=axis)
+    else:
+        return lossterms
+#    return tensor.sum(beta * tensor.exp(output) - target*output + logfactorial(target), axis=axis)
+
+
+#import numpy
+#def nlpoisson_nontheano(target, output, beta_scale=1, axis=0):
+#    doclen = numpy.sum(target, axis=axis)
+#    print "doclen", doclen
+#    beta = beta_scale * doclen
+#    print "beta", beta
+#    print "exp", numpy.exp(output)
+#    print "beta * exp", beta * numpy.exp(output)
+#    print "x * y", target * output
+#
+#    import theano.tensor as TT
+#    x = TT.as_tensor(target)
+#    o = logfactorial(x)
+#    f = T.function([],o)
+#    logf = f()
+#    print "log factorial(x)", logf
+#    print "beta * exp - dot + log factorial", beta * numpy.exp(output) - target*output + f()
+#    print "total loss", numpy.sum(beta * numpy.exp(output) - target*output + f(), axis=axis)
+#
+##    return beta * numpy.exp(output) - numpy.dot(target, output)
+##            #+ logfactorial(target)
+#
+#import numpy
+#target = numpy.array([0, 0, 1, 1, 2, 2, 100, 100])
+##output = numpy.array([0., 0.5, 1., 0.5, 2., 0.5, 100., 0.5])
+#output = numpy.array([0., 1, 1., 0, 1, 0, 5, 1])
+#nlpoisson_nontheano(target, output)
diff -r d69e668ab904 -r 719194960d18 pylearn/algorithms/sandbox/test_cost.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/algorithms/sandbox/test_cost.py	Mon Mar 30 12:26:01 2009 -0400
@@ -0,0 +1,53 @@
+import pylearn.algorithms.sandbox.cost as cost
+
+import unittest
+import theano as T
+import theano.tensor as TT
+import numpy
+
+class T_logfactorial(unittest.TestCase):
+    def test(self):
+        x = TT.as_tensor(range(10))
+        o = cost.logfactorial(x)
+        f = T.function([],o)
+        self.failUnless(numpy.all(f() - numpy.asarray([0., 0., 1.38629436, 3.29583687, 5.54517744, 8.04718956, 10.75055682, 13.62137104, 16.63553233, 19.7750212])) < 1e-5)
+
+    def test_float(self):
+        """
+        This should fail because we can't use floats in logfactorial
+        """
+        x = TT.as_tensor([0.5, 2.7])
+        o = cost.logfactorial(x)
+        f = T.function([],o)
+#        print repr(f())
+        self.failUnless(numpy.all(f() == numpy.asarray([0., 0., 1.38629436, 3.29583687, 5.54517744, 8.04718956, 10.75055682, 13.62137104, 16.63553233, 19.7750212])))
+
+class T_nlpoisson(unittest.TestCase):
+    def test(self):
+        target = TT.as_tensor([0, 0, 1, 1, 2, 2, 100, 100])
+        output = TT.as_tensor([0., 1, 1., 0, 1, 0, 5, 1])
+        o = cost.nlpoisson(target, output)
+        f = T.function([],o)
+        self.failUnless(f() - 33751.7816277 < 1e-5)
+
+    def test_gradient(self):
+        target = TT.as_tensor([0, 0, 1, 1, 2, 2, 100, 100])
+        output = TT.as_tensor([0., 1, 1., 0, 1, 0, 5, 1])
+        loss = cost.nlpoisson(target, output)
+        (goutput) = TT.grad(loss, [output])
+#        (goutput) = TT.grad(loss, [target])
+        f = T.function([], goutput)
+        print f()
+        self.failUnless(numpy.all(f() - numpy.asarray([206., 559.96605666, 558.96605666, 205., 557.96605666, 204., 30473.11077513, 459.96605666] < 1e-5)))
+
+    def test_gradient_fail(self):
+        target = TT.as_tensor([0, 0, 1, 1, 2, 2, 100, 100])
+        output = TT.as_tensor([0., 1, 1., 0, 1, 0, 5, 1])
+        loss = cost.nlpoisson(target, output)
+        (goutput) = TT.grad(loss, [target])
+        f = T.function([], goutput)
+        print f()
+        self.failUnless(numpy.all(f() - numpy.asarray([206., 559.96605666, 558.96605666, 205., 557.96605666, 204., 30473.11077513, 459.96605666] < 1e-5)))
+
+if __name__ == '__main__':
+    unittest.main()
diff -r d69e668ab904 -r 719194960d18 pylearn/algorithms/stacker.py
--- a/pylearn/algorithms/stacker.py	Mon Mar 30 12:25:42 2009 -0400
+++ b/pylearn/algorithms/stacker.py	Mon Mar 30 12:26:01 2009 -0400
@@ -7,10 +7,11 @@
 
 import theano
 from theano import tensor as T
+from theano.tensor.deprecated import rmodule
 import sys
 import numpy as N
 
-class Stacker(T.RModule):
+class Stacker(rmodule.RModule):
     """
     @note: Assumes some names in the layers: input, cost, lr, and update
     @todo: Maybe compile functions on demand, rather than immediately.
diff -r d69e668ab904 -r 719194960d18 pylearn/datasets/MNIST.py
--- a/pylearn/datasets/MNIST.py	Mon Mar 30 12:25:42 2009 -0400
+++ b/pylearn/datasets/MNIST.py	Mon Mar 30 12:26:01 2009 -0400
@@ -53,6 +53,14 @@
 def full():
     return train_valid_test()
 
+#usefull for test, keep it
+def first_10():
+    return train_valid_test(ntrain=10, nvalid=10, ntest=10)
+
+#usefull for test, keep it
+def first_100():
+    return train_valid_test(ntrain=100, nvalid=100, ntest=100)
+
 def first_1k():
     return train_valid_test(ntrain=1000, nvalid=200, ntest=200)
 
diff -r d69e668ab904 -r 719194960d18 pylearn/datasets/config.py
--- a/pylearn/datasets/config.py	Mon Mar 30 12:25:42 2009 -0400
+++ b/pylearn/datasets/config.py	Mon Mar 30 12:26:01 2009 -0400
@@ -5,12 +5,14 @@
 """
 import os, sys
 
-def env_get(key, default):
+def env_get(key, default, key2 = None):
+    if key2 and os.getenv(key) is None:
+        key=key2
     if os.getenv(key) is None:
         print >> sys.stderr, "WARNING: Environment variable", key,
         print >> sys.stderr, "is not set. Using default of", default
     return default if os.getenv(key) is None else os.getenv(key)
 
 def data_root():
-    return env_get('PYLEARN_DATA_ROOT', os.getenv('HOME')+'/data')
+    return env_get('PYLEARN_DATA_ROOT', os.getenv('HOME')+'/data', 'DBPATH')
 
diff -r d69e668ab904 -r 719194960d18 pylearn/external/wrap_libsvm.py
--- a/pylearn/external/wrap_libsvm.py	Mon Mar 30 12:25:42 2009 -0400
+++ b/pylearn/external/wrap_libsvm.py	Mon Mar 30 12:26:01 2009 -0400
@@ -2,7 +2,6 @@
 """
 import numpy
 from ..datasets import make_dataset
-from ..dbdict.experiment import subdict_copy
 
 # libsvm currently has no python installation instructions/convention.
 #
@@ -55,8 +54,7 @@
 
     This is the kind of function that dbdict-run can use.
     """
-    dataset = make_dataset(**subdict_copy(state, 'dataset_'))
-
+    dataset = make_dataset(**state.dataset)
 
     #libsvm needs stuff in int32 on a 32bit machine
 
diff -r d69e668ab904 -r 719194960d18 pylearn/sandbox/test_speed.py
--- a/pylearn/sandbox/test_speed.py	Mon Mar 30 12:25:42 2009 -0400
+++ b/pylearn/sandbox/test_speed.py	Mon Mar 30 12:26:01 2009 -0400
@@ -1,5 +1,5 @@
 import numpy
-from dataset import *
+from pylearn.datasets import *
 from misc import *
 def test_speed(array, ds):
     print "test_speed", ds.__class__
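
The nlpoisson docstring above is formula-heavy, so here is a minimal numpy sketch of the loss the new sandbox cost computes, assuming beta_scale=1 and leaving out the log-factorial term that the docstring still lists as a @todo. The target/output arrays are the ones used in test_cost.py; the helper name nlpoisson_numpy and the printed value are illustrative only, not part of the changeset.

import numpy

def nlpoisson_numpy(target, output, beta_scale=1, axis=0):
    # beta is proportional to the document length (the summed word counts)
    doclen = numpy.sum(target, axis=axis)
    beta = beta_scale * doclen
    # lambda_i = beta * exp(output_i) is the expected count of word i
    lambdav = beta * numpy.exp(output)
    # per-word negative log Poisson terms, without the log(target!) piece
    lossterms = lambdav - target * output
    return numpy.sum(lossterms, axis=axis)

target = numpy.array([0, 0, 1, 1, 2, 2, 100, 100])
output = numpy.array([0., 1, 1., 0, 1, 0, 5, 1])
print nlpoisson_numpy(target, output)    # roughly 32828 for these arrays

Differentiating each term with respect to output gives beta*exp(output) - target, which is where the per-component values asserted in T_nlpoisson.test_gradient come from.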
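
Similarly, a small sketch of the new env_get fallback in pylearn/datasets/config.py, assuming the hunk above is applied as shown; the environment values here are made up for illustration.

import os
os.environ.pop('PYLEARN_DATA_ROOT', None)   # primary key unset
os.environ['DBPATH'] = '/tmp/data'          # made-up fallback location

from pylearn.datasets import config
print config.data_root()                    # -> '/tmp/data', taken from the DBPATH fallback
print config.env_get('PYLEARN_DATA_ROOT', '/tmp/default')
# with no key2 and the variable unset, this warns on stderr and returns '/tmp/default'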