# pylearn/preprocessing/fft_whiten.py @ 1416:28b2f17991aa
# author: James Bergstra <bergstrj@iro.umontreal.ca>
# date: Fri, 04 Feb 2011 16:01:45 -0500
# parents: 2eb98a740823

"""
Whitening algorithm used in Olshausen & Field, 1997.

If you want to use some other images, there are a number of
preprocessing steps you need to consider beforehand.  First, you
should make sure all images have approximately the same overall
contrast.  One way of doing this is to normalize each image so that
the variance of the pixels is the same (i.e., 1).  Then you will need
to prewhiten the images.  For a full explanation of whitening, see:

Olshausen BA, Field DJ (1997). Sparse Coding with an Overcomplete
Basis Set: A Strategy Employed by V1? Vision Research, 37: 3311-3325.
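
The per-image contrast normalization described above might look like
this in NumPy.  It is an illustrative sketch, not part of this module:
the helper name normalize_contrast is hypothetical, and dividing each
image by its own pixel standard deviation is one assumed way of
equalizing the variances.

```python
import numpy as np


def normalize_contrast(images, target_var=1.0):
    # Hypothetical helper. images: array of shape (n_images, rows, cols).
    # Scale each image so its pixel variance equals target_var
    # (1 by default, as suggested above). Constant images (zero
    # variance) would cause a division by zero and are not handled.
    images = np.asarray(images, dtype=float)
    std = images.std(axis=(1, 2), keepdims=True)
    return images * (np.sqrt(target_var) / std)


# Example: five random images with different contrasts.
rng = np.random.default_rng(0)
raw = rng.normal(size=(5, 32, 32)) * rng.uniform(0.5, 3.0, size=(5, 1, 1))
normed = normalize_contrast(raw)
print(normed.var(axis=(1, 2)))  # every entry is 1 up to rounding
```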

Basically, to whiten an image of size NxN, you multiply its Fourier
transform by the filter f*exp(-(f/f_0)^4), where f is the radial
frequency sqrt(fx^2 + fy^2) and f_0=0.4*N (f_0 is the cutoff frequency
of a lowpass filter that is combined with the whitening filter).  Once
you have preprocessed a number of images of the same size this way,
combine them into one big N^2 x M array, where M is the number of
images.  Then rescale this array so that the average image variance is
0.1 (the parameters in sparsenet are set assuming that the data
variance is 0.1).  Name this array IMAGES, save it to a file for future
use, and you should be off and running.  The following Matlab commands
should do this:


    N=image_size;
    M=num_images;

    [fx fy]=meshgrid(-N/2:N/2-1,-N/2:N/2-1);
    rho=sqrt(fx.*fx+fy.*fy);
    f_0=0.4*N;
    filt=rho.*exp(-(rho/f_0).^4);

    for i=1:M
        image=get_image;  % you will need to provide get_image
        If=fft2(image);
        imagew=real(ifft2(If.*fftshift(filt)));
        IMAGES(:,i)=reshape(imagew,N^2,1);
    end

    IMAGES=sqrt(0.1)*IMAGES/sqrt(mean(var(IMAGES)));

    save MY_IMAGES IMAGES
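
For readers working in NumPy, the Matlab steps above can be sketched as
follows.  The helper names whitening_filter and whiten_and_stack are
hypothetical (they are not part of pylearn), and the column-major
reshape (order="F") and sample variance (ddof=1) are assumptions chosen
to mirror Matlab's reshape and var.

```python
import numpy as np


def whitening_filter(n, f0=None):
    # Hypothetical helper: radial filter f * exp(-(f / f0)**4) on an
    # n x n integer frequency grid, zero frequency at the centre as in
    # meshgrid(-N/2:N/2-1, -N/2:N/2-1).
    if f0 is None:
        f0 = 0.4 * n
    k = np.arange(n) - n // 2          # -n/2 .. n/2-1 for even n
    fx, fy = np.meshgrid(k, k)
    rho = np.sqrt(fx ** 2 + fy ** 2)
    return rho * np.exp(-(rho / f0) ** 4)


def whiten_and_stack(images, f0=None, target_var=0.1):
    # Hypothetical helper. images: sequence of M arrays, each n x n, n even.
    # Returns an (n*n, M) array with one whitened image per column,
    # rescaled so the mean per-image variance equals target_var.
    images = [np.asarray(im, dtype=float) for im in images]
    n = images[0].shape[0]
    # ifftshift moves zero frequency from the centre to index (0, 0),
    # the layout fft2 uses (identical to fftshift for even n).
    filt = np.fft.ifftshift(whitening_filter(n, f0))
    columns = []
    for im in images:
        imagew = np.real(np.fft.ifft2(np.fft.fft2(im) * filt))
        # order="F" mimics Matlab's column-major reshape(imagew, N^2, 1).
        columns.append(imagew.reshape(-1, order="F"))
    IMAGES = np.stack(columns, axis=1)
    # Matlab's var uses the N-1 normalization, hence ddof=1.
    mean_var = IMAGES.var(axis=0, ddof=1).mean()
    return IMAGES * np.sqrt(target_var / mean_var)


rng = np.random.default_rng(0)
IMAGES = whiten_and_stack([rng.normal(size=(32, 32)) for _ in range(4)])
print(IMAGES.shape)  # (1024, 4)
```

Because the filter is zero at f = 0, each whitened image has zero mean.
Setting the derivative of f*exp(-(f/f_0)^4) to zero gives
1 - 4*(f/f_0)^4 = 0, so the filter peaks at f = f_0/sqrt(2), about
0.28*N for the default f_0 = 0.4*N.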

"""

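import numpy as np


def whiten(X, f0=None, n=None):
    """
    Whiten images with the Olshausen & Field (1997) frequency-domain filter
    rho * exp(-(rho / f0) ** n).

    :param X: array of shape (n_images, rows, cols); a single (rows, cols)
        image also works. rows and cols must be even.
    :param f0: cutoff frequency of the low-pass term, in frequency-index
        units (default: 0.4 * min(rows, cols)).
    :param n: exponent of the low-pass roll-off (default: 4).

    :returns: array of the same shape as X holding the filtered images.
    """
    # Matlab's meshgrid(x, y) indexes its outputs as (y, x), whereas np.mgrid
    # indexes them as (row, col). The filter depends only on the radius rho,
    # so this ordering difference does not change the result.
    R, C = X.shape[-2:]
    if R % 2:
        raise NotImplementedError("odd number of rows is not supported")
    if C % 2:
        raise NotImplementedError("odd number of columns is not supported")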

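    if f0 is None:
        f0 = .4 * min(R, C)
    if n is None:
        n = 4

    # Integer frequency grid with zero frequency at index (R // 2, C // 2),
    # matching Matlab's -N/2:N/2-1 range.
    fx, fy = np.mgrid[-R // 2:R // 2, -C // 2:C // 2]
    rho = np.sqrt(fx ** 2 + fy ** 2)
    # Whitening ramp (rho) times the low-pass roll-off exp(-(rho / f0) ** n).
    filt = rho * np.exp(-(rho / f0) ** n)
    # fft2 transforms the last two axes, so X may hold one image or a stack.
    If = np.fft.fft2(X)
    # fftshift moves zero frequency from the centre of filt to index (0, 0),
    # the layout used by fft2 (for even sizes it equals ifftshift).
    imagew = np.real(np.fft.ifft2(If * np.fft.fftshift(filt)))
    assert imagew.shape == X.shape
    return imagew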