view pylearn/preprocessing/fft_whiten.py @ 1416:28b2f17991aa
80char
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 04 Feb 2011 16:01:45 -0500 |
parents | 2eb98a740823 |
children |
""" Whitening algorithm used in Olshausen & Field, 1997. If you want to use some other images, there are a number of preprocessing steps you need to consider beforehand. First, you should make sure all images have approximately the same overall contrast. One way of doing this is to normalize each image so that the variance of the pixels is the same (i.e., 1). Then you will need to prewhiten the images. For a full explanation of whitening see Olshausen BA, Field DJ (1997) Sparse Coding with an Overcomplete Basis Set: A Strategy Employed by V1? Vision Research, 37: 3311-3325. Basically, to whiten an image of size NxN, you multiply by the filter f*exp(-(f/f_0)^4) in the frequency domain, where f_0=0.4*N (f_0 is the cutoff frequency of a lowpass filter that is combined with the whitening filter). Once you have preprocessed a number of images this way, all the same size, then you should combine them into one big N^2 x M array, where M is the number of images. Then rescale this array so that the average image variance is 0.1 (the parameters in sparsenet are set assuming that the data variance is 0.1). Name this array IMAGES, save it to a file for future use, and you should be off and running. The following Matlab commands should do this: N=image_size; M=num_images; [fx fy]=meshgrid(-N/2:N/2-1,-N/2:N/2-1); rho=sqrt(fx.*fx+fy.*fy); f_0=0.4*N; filt=rho.*exp(-(rho/f_0).^4); for i=1:M image=get_image; % you will need to provide get_image If=fft2(image); imagew=real(ifft2(If.*fftshift(filt))); IMAGES(:,i)=reshape(imagew,N^2,1); end IMAGES=sqrt(0.1)*IMAGES/sqrt(mean(var(IMAGES))); save MY_IMAGES IMAGES """ import numpy as np def whiten(X, f0=None, n=None): """ :param X: a 3D tensor n_images x rows x cols :param f0: filter parameter (see docs) :param n: filter parameter (see docs) :returns: 3D tensor n_images x rows x cols of filtered images. 
""" # May be mixed up with the size2 and size1s because matlab does things # differnetly R, C = X.shape[-2:] if R %2: raise NotImplementedError() if C %2: raise NotImplementedError() if f0 is None: f0 = .4 * min(R,C) if n is None: n = 4 fx,fy = np.mgrid[-R/2:R/2, -C/2:C/2] rho=np.sqrt(fx**2 + fy**2) filt=rho * np.exp( - (rho/f0)**4) If = np.fft.fft2(X) imagew=np.real(np.fft.ifft2(If * np.fft.fftshift(filt))) assert imagew.shape == X.shape return imagew