view pylearn/algorithms/sparse_coding.py @ 1496:93b8373c6735

Prefix loggers with 'pylearn.' to ensure there is no conflict when using Pylearn code within another library
author Olivier Delalleau <delallea@iro>
date Mon, 22 Aug 2011 11:28:48 -0400
parents e88d7b7d53ed
children
line wrap: on
line source

import sys, logging, os
import numpy, PIL.Image, scipy.optimize
import theano


from theano import shared, function
import theano.tensor as TT
from theano.tensor import matrix, vector, scalar, dot, grad, switch, clip
floatX = theano.config.floatX

from pylearn.io.image_tiling import tile_raster_images
from pylearn.datasets import image_patches

def sample_codebook_prior(shape, rng, dtype=floatX):
    """Draw a random codebook whose entries along the first axis have unit norm.

    :param shape: shape of the codebook to draw; unpacked into ``rng.randn``.
    :param rng: random source exposing ``randn(*shape)``
        (e.g. a ``numpy.random.RandomState``).
    :param dtype: dtype of the returned array (defaults to theano's floatX).
    :returns: array of the requested shape and dtype in which each ``rval[i]``
        has been divided by the square root of its sum of squares.

    Also prints a reminder on stderr that the prior is still a placeholder.
    """
    codebook = numpy.asarray(rng.randn(*shape), dtype=dtype)
    for idx in range(len(codebook)):
        atom = codebook[idx]
        # Divide by the L2 norm so that every codebook entry has unit length.
        length = numpy.sqrt((atom ** 2).sum())
        codebook[idx] = atom / length
    print >> sys.stderr, "TODO: pick a codebook prior"
    return codebook

def numpy_project_onto_ball(X):
    """Return a copy of X with rows scaled to unit length.

    :param X: 2-D numpy array; each row is rescaled independently.
    :returns: a new array with X's shape in which every row that has a
        non-zero L2 norm is divided by that norm.  All-zero rows are returned
        as zeros instead of being turned into NaN by a 0/0 division.
        X itself is not modified.
    """
    norms = numpy.sqrt((X**2).sum(axis=1))
    # Column vector, so that the division below broadcasts across each row.
    norms.shape = (X.shape[0], 1)
    # A zero row has no direction to preserve: dividing it by 1 keeps it at
    # zero, whereas dividing by its zero norm would yield NaN and contaminate
    # every later computation that uses the result.  `norms` is a fresh array
    # (the output of sqrt), so patching it in place does not touch X.
    norms[norms == 0] = 1
    return X / norms

def reproduce_olshausen():
    """Learn a dictionary of image-patch features by sparse coding.

    Presumably a reproduction of Olshausen & Field's sparse-coding experiment
    (going by the function name -- confirm against the reference being used).

    For each minibatch of `batchsize` patches the procedure alternates between
      1. inferring the codes Ci for the current dictionary Z with
         scipy.optimize.fmin_cg, and
      2. taking one step on Z along Ci.T . (Xi - Ci . Z), then rescaling the
         rows of Z to unit length.

    Takes no arguments and returns nothing.  Side effects: prints progress to
    stdout and writes 'X100.png' and 'Z_j=<j>.png' images using relative
    paths (i.e. into the current working directory).
    """

    # Fixed seed so that the initial dictionary is the same on every run.
    rng = numpy.random.RandomState(89234)
    ##
    # load the data into X
    ##

    #TODO: Make sure this new way of loading the data still gives the
    #      right results (check mean and variance!)
    X = numpy.asarray(image_patches.load_patches().train.x, dtype=floatX)
    # NOTE(review): img_shape is never read; the tile_raster_images calls in
    # this function repeat the literal (20,20) instead.
    img_shape = (20,20)
    # Save a picture of the first 100 raw patches (before normalization).
    # Presumably the arguments are (data, image shape, grid shape, spacing)
    # -- confirm against pylearn.io.image_tiling.
    PIL.Image.fromarray(
            tile_raster_images(X[:100], (20,20), (10,10), (1,1)), 
            'L').save('X100.png')

    # Center every column of X.
    X -= numpy.mean(X, axis=0)

    # Constants below are set up for a variance of .01
    # Dividing by 10*std leaves each column with a standard deviation of
    # about 0.1; the 1e-8 guards against division by zero on constant columns.
    X /= 10*numpy.std(X, axis=0)+1e-8

    batchsize=100
    # tile_dims is the grid used to draw the dictionary, and its area is the
    # number of dictionary elements (24*24 = 576).
    tile_dims=(24,24)
    nZ=tile_dims[0]*tile_dims[1]

    #C = numpy.asarray((X.shape[0], nZ), dtype=floatX)

    ##
    # symbolic stuff
    ##
    sX = matrix()                           # minibatch of patches
    # The codes are passed in as one flat vector because fmin_cg optimizes
    # over a 1-D array; sC is the same data viewed as (batchsize, nZ).
    flat_C = vector()
    sC = flat_C.reshape((batchsize, nZ))
    sZ = matrix()                           # dictionary

    # Reconstruction of the minibatch from its codes and the dictionary.
    sXpred = dot(sC, sZ)

    # cost_X: summed squared reconstruction error.
    # cost_C: log(1 + (c/.116)**2) penalty summed over all code entries.
    cost_X = ((sX - sXpred)**2).sum()
    cost_C = TT.log(1+(sC/.116)**2).sum()
    cost = 100.0 * cost_X + 2.2 * cost_C
    gC, gZ = grad(cost, [sC, sZ])

    # Objective and gradient in the flat-vector form that fmin_cg expects.
    # NOTE(review): debug_fn is only referenced from a commented-out print.
    cost_fn = function([sX, flat_C, sZ], cost)
    debug_fn = function([sX, flat_C, sZ], [cost, cost_X, cost_C])
    gC_fn = function([sX, flat_C, sZ], gC.flatten())

    # Takes the code matrix sC directly (not the flat vector); only used by
    # the disabled joint-optimization branch below.
    gC_gZ_fn = function([sX, sC, sZ], [cost, gC,gZ])

    # sample some imgs from the dictionary prior
    # Z has one row per dictionary element and one column per column of X.
    Z = sample_codebook_prior((nZ, X.shape[1]), rng)


    # loop over the data by 100 examples at a time
    # (1000 outer iterations of 10 minibatches each; statistics and images
    # are produced once per outer iteration.)
    for j in xrange(1000):
        for i in xrange(10):
            # Walk through X in consecutive minibatches, wrapping around at
            # the end; a slice cut short by the end of X is skipped.
            offset = ((j*10+i)*batchsize) % len(X)
            Xi = X[offset:offset+batchsize]
            if len(Xi) != batchsize:
                continue

            #Ci = C[offset:offset+batchsize]

            # get the optimal C given X and Z by cg
            # The search starts from the projection of the patches onto the
            # dictionary (Xi . Z.T) and is capped at 20 iterations.
            Ci = scipy.optimize.fmin_cg(
                    f=lambda c:cost_fn(Xi, c, Z),
                    x0=numpy.dot(Xi, Z.T).flatten(),
                    fprime=lambda c: gC_fn(Xi, c, Z),
                    maxiter=20,
                    gtol=1e-2,
                    )
            # fmin_cg returns the flat vector; view it as (batchsize, nZ).
            Ci.shape = (batchsize, nZ)

            print "j", j,
            print "i", i,
            print "Ci**2", (Ci**2).sum(),

            if 1: # use published algo

                # What the current codes and dictionary fail to reconstruct.
                Xi_residual = Xi - numpy.dot(Ci, Z)

                # Up to a constant factor this is the negative gradient of
                # the squared reconstruction error with respect to Z.
                dZ = numpy.dot(Ci.T, Xi_residual)

                #print "solution", debug_fn(Xi, Ci.flatten(), Z)
                print "dXi**2", (Xi_residual**2).sum(),
                print "dZ**2", (dZ**2).sum()

                eta = 3e-1

                # Step on the dictionary, then rescale its rows to unit length.
                Z += eta * dZ
                Z = numpy_project_onto_ball(Z)
            if 0:
                # Disabled alternative, kept for reference.
                # use joint optimization of Z and C 
                # This is a little faster, but not much.
                #
                # It is important not to change Z too much for any C
                # because small changes in Z can trigger large changes in the corresponding
                # optimal C for a given X.  That's a natural consequence of the active
                # sparsification.

                eta = 1e-3
                for k in xrange(3):
                    l, dC,dZ = gC_gZ_fn(Xi, Ci, Z)
                    print "k", k,
                    print "l", l, 
                    print "dCi**2", (dC**2).sum(),
                    print "dZ**2", (dZ**2).sum()
                    #Ci -= eta * dC
                    Z -= eta * dZ
                    Z = numpy_project_onto_ball(Z)

        # Histogram of the code magnitudes of the last minibatch processed.
        # NOTE(review): if every inner iteration above hit `continue`, Ci is
        # left over from an earlier j (or is unbound when j == 0).
        hist,edges = numpy.histogram(abs(Ci))
        print "Hist"
        # edges has one more entry than hist, so zip pairs every count with
        # the left edge of its bin and drops the final right edge.
        for h, e in zip(hist, edges):
            print '%.3f \t %.3f' % (e, h)

        # Save a picture of the dictionary: after each of the first 100 outer
        # iterations, then only after every 10th.
        if j < 100:
            PIL.Image.fromarray(
                    tile_raster_images(Z, (20,20), tile_dims, (1,1)), 
                    'L').save('Z_j=%i.png'%j)
        else:
            if (j % 10) == 0:
                PIL.Image.fromarray(
                        tile_raster_images(Z, (20,20), tile_dims, (1,1)), 
                        'L').save('Z_j=%i.png'%j)


if __name__ == '__main__':
    # Route log records to stderr and make sure this script's own logger
    # emits INFO-level messages.
    logging.basicConfig(stream=sys.stderr)
    main_logger = logging.getLogger('pylearn.algorithms.sparse_coding.main')
    main_logger.setLevel(logging.INFO)
    main_logger.info('hello')

    # Run the sparse-coding experiment on the image patches.
    reproduce_olshausen()