Mercurial > pylearn
changeset 966:e88d7b7d53ed
adding algorithms/sparse_coding
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 20 Aug 2010 13:58:41 -0400 |
parents | bf54637bb994 |
children | 90e11d5d0a41 |
files | pylearn/algorithms/sparse_coding.py |
diffstat | 1 files changed, 159 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/algorithms/sparse_coding.py Fri Aug 20 13:58:41 2010 -0400 @@ -0,0 +1,159 @@ +import sys, logging, os +import numpy, PIL.Image, scipy.optimize +import theano + + +from theano import shared, function +import theano.tensor as TT +from theano.tensor import matrix, vector, scalar, dot, grad, switch, clip +floatX = theano.config.floatX + +from pylearn.io.image_tiling import tile_raster_images +from pylearn.datasets import image_patches + +def sample_codebook_prior(shape, rng, dtype=floatX): + rval = numpy.asarray(rng.randn(*shape), dtype=dtype) + for i, img in enumerate(rval): + rval[i] = img / numpy.sqrt( (img**2).sum()) + print >> sys.stderr, "TODO: pick a codebook prior" + return rval + +def numpy_project_onto_ball(X): + """Return a copy of X with rows scaled to unit length + """ + norms = numpy.sqrt((X**2).sum(axis=1)) + norms.shape = (X.shape[0], 1) + return X / norms + +def reproduce_olshausen(): + + rng = numpy.random.RandomState(89234) + ## + # load the data into X + ## + + #TODO: Make sure this new way of loading the data still gives the + # right results (check mean and variance!) + X = numpy.asarray(image_patches.load_patches().train.x, dtype=floatX) + img_shape = (20,20) + PIL.Image.fromarray( + tile_raster_images(X[:100], (20,20), (10,10), (1,1)), + 'L').save('X100.png') + + X -= numpy.mean(X, axis=0) + + # Constants below are set up for a variance of .01 + X /= 10*numpy.std(X, axis=0)+1e-8 + + batchsize=100 + tile_dims=(24,24) + nZ=tile_dims[0]*tile_dims[1] + + #C = numpy.asarray((X.shape[0], nZ), dtype=floatX) + + ## + # symbolic stuff + ## + sX = matrix() + flat_C = vector() + sC = flat_C.reshape((batchsize, nZ)) + sZ = matrix() + + sXpred = dot(sC, sZ) + + cost_X = ((sX - sXpred)**2).sum() + cost_C = TT.log(1+(sC/.116)**2).sum() + cost = 100.0 * cost_X + 2.2 * cost_C + gC, gZ = grad(cost, [sC, sZ]) + + cost_fn = function([sX, flat_C, sZ], cost) + debug_fn = function([sX, flat_C, sZ], [cost, cost_X, cost_C]) + gC_fn = function([sX, flat_C, sZ], gC.flatten()) + + gC_gZ_fn = function([sX, sC, sZ], [cost, gC,gZ]) + + # sample some imgs from the dictionary prior + Z = sample_codebook_prior((nZ, X.shape[1]), rng) + + + # loop over the data by 100 examples at a time + for j in xrange(1000): + for i in xrange(10): + offset = ((j*10+i)*batchsize) % len(X) + Xi = X[offset:offset+batchsize] + if len(Xi) != batchsize: + continue + + #Ci = C[offset:offset+batchsize] + + # get the optimal C given X and Z by cg + Ci = scipy.optimize.fmin_cg( + f=lambda c:cost_fn(Xi, c, Z), + x0=numpy.dot(Xi, Z.T).flatten(), + fprime=lambda c: gC_fn(Xi, c, Z), + maxiter=20, + gtol=1e-2, + ) + Ci.shape = (batchsize, nZ) + + print "j", j, + print "i", i, + print "Ci**2", (Ci**2).sum(), + + if 1: # use published algo + + Xi_residual = Xi - numpy.dot(Ci, Z) + + dZ = numpy.dot(Ci.T, Xi_residual) + + #print "solution", debug_fn(Xi, Ci.flatten(), Z) + print "dXi**2", (Xi_residual**2).sum(), + print "dZ**2", (dZ**2).sum() + + eta = 3e-1 + + Z += eta * dZ + Z = numpy_project_onto_ball(Z) + if 0: + # use joint optimization of Z and C + # This is a little faster, but not much. + # + # It is important not to change Z too much for any C + # because small changes in Z can trigger large changes in the corresponding + # optimal C for a given X. That's a natural consequence of the active + # sparsification. + + eta = 1e-3 + for k in xrange(3): + l, dC,dZ = gC_gZ_fn(Xi, Ci, Z) + print "k", k, + print "l", l, + print "dCi**2", (dC**2).sum(), + print "dZ**2", (dZ**2).sum() + #Ci -= eta * dC + Z -= eta * dZ + Z = numpy_project_onto_ball(Z) + + hist,edges = numpy.histogram(abs(Ci)) + print "Hist" + for h, e in zip(hist, edges): + print '%.3f \t %.3f' % (e, h) + + if j < 100: + PIL.Image.fromarray( + tile_raster_images(Z, (20,20), tile_dims, (1,1)), + 'L').save('Z_j=%i.png'%j) + else: + if (j % 10) == 0: + PIL.Image.fromarray( + tile_raster_images(Z, (20,20), tile_dims, (1,1)), + 'L').save('Z_j=%i.png'%j) + + +if __name__ == '__main__': + logging.basicConfig(stream=sys.stderr) + logging.getLogger('main').setLevel(logging.INFO) + logging.getLogger('main').info('hello') + + # load olshausen images + reproduce_olshausen()