Mercurial > pylearn
changeset 654:2704c8688ced
merge
author | bergstra@mlp4.ais.sandbox |
---|---|
date | Wed, 11 Feb 2009 01:43:14 -0500 |
parents | d3d8f5a17909 (diff) d03b5d8e4bf6 (current diff) |
children | 14d22ca1c8b5 d69e668ab904 |
files | bin/dbdict-query bin/dbdict-run bin/dbdict-run-job pylearn/dbdict/__init__.py pylearn/dbdict/api0.py pylearn/dbdict/crap.py pylearn/dbdict/dbdict_run.py pylearn/dbdict/dbdict_run_sql.py pylearn/dbdict/dconfig.py pylearn/dbdict/design.txt pylearn/dbdict/experiment.py pylearn/dbdict/newstuff.py pylearn/dbdict/sample_create_jobs.py pylearn/dbdict/scratch.py pylearn/dbdict/sql.py pylearn/dbdict/sql_commands.py pylearn/dbdict/test_api0.py pylearn/dbdict/tests/test_experiment.py pylearn/dbdict/tools.py |
diffstat | 3 files changed, 90 insertions(+), 3 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/datasets/MNIST.py Wed Feb 04 20:02:05 2009 -0500 +++ b/pylearn/datasets/MNIST.py Wed Feb 11 01:43:14 2009 -0500 @@ -7,7 +7,7 @@ import numpy from ..io.amat import AMat -from .config import data_root +from .config import data_root # config from .dataset import Dataset def head(n=10, path=None):
--- a/pylearn/datasets/config.py Wed Feb 04 20:02:05 2009 -0500 +++ b/pylearn/datasets/config.py Wed Feb 11 01:43:14 2009 -0500 @@ -4,10 +4,13 @@ Especially, the locations of data files. """ -import os +import os, sys def env_get(key, default): + if os.getenv(key) is None: + print >> sys.stderr, "WARNING: Environment variable", key, + print >> sys.stderr, "is not set. Using default of", default return default if os.getenv(key) is None else os.getenv(key) def data_root(): - return env_get('PYLEARN_DATA_ROOT', '/u/bergstrj/pub/data/') + return env_get('PYLEARN_DATA_ROOT', os.getenv('HOME')+'/data')
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/io/image_tiling.py Wed Feb 11 01:43:14 2009 -0500 @@ -0,0 +1,84 @@ +""" +Illustrate filters (or data) in a grid of small image-shaped tiles. +""" + +import numpy +from PIL import Image + +def scale_to_unit_interval(ndar): + ndar = ndar.copy() + ndar -= ndar.min() + ndar *= 1.0 / ndar.max() + return ndar + +def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0,0), + scale_rows_to_unit_interval=True, + output_pixel_vals=True + ): + """ + Transform an array with one flattened image per row, into an array in which images are + reshaped and layed out like tiles on a floor. + + This function is useful for visualizing datasets whose rows are images, and also columns of + matrices for transforming those rows (such as the first layer of a neural net). + + :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can be 2-D ndarrays or None + :param X: a 2-D array in which every row is a flattened image. + :type img_shape: tuple; (height, width) + :param img_shape: the original shape of each image + :type tile_shape: tuple; (rows, cols) + :param tile_shape: the number of images to tile (rows, cols) + + :returns: array suitable for viewing as an image. (See:`PIL.Image.fromarray`.) + :rtype: a 2-d array with same dtype as X. + + """ + assert len(img_shape) == 2 + assert len(tile_shape) == 2 + assert len(tile_spacing) == 2 + + out_shape = [(ishp + tsp) * tshp - tsp for ishp, tshp, tsp + in zip(img_shape, tile_shape, tile_spacing)] + + if isinstance(X, tuple): + assert len(X) == 4 + if output_pixel_vals: + out_array = numpy.zeros((out_shape[0], out_shape[1], 4), dtype='uint8') + else: + out_array = numpy.zeros((out_shape[0], out_shape[1], 4), dtype=X.dtype) + + #colors default to 0, alpha defaults to 1 (opaque) + if output_pixel_vals: + channel_defaults = [0,0,0,255] + else: + channel_defaults = [0.,0.,0.,1.] + + for i in xrange(4): + if X[i] is None: + out_array[:,:,i] = numpy.zeros(out_shape, + dtype='uint8' if output_pixel_vals else out_array.dtype + )+channel_defaults[i] + else: + out_array[:,:,i] = tile_raster_images(X[i], img_shape, tile_shape, tile_spacing, scale_rows_to_unit_interval, output_pixel_vals) + return out_array + + else: + H, W = img_shape + Hs, Ws = tile_spacing + + out_array = numpy.zeros(out_shape, dtype='uint8' if output_pixel_vals else X.dtype) + for tile_row in xrange(tile_shape[0]): + for tile_col in xrange(tile_shape[1]): + if tile_row * tile_shape[1] + tile_col < X.shape[0]: + if scale_rows_to_unit_interval: + this_img = scale_to_unit_interval(X[tile_row * tile_shape[1] + tile_col].reshape(img_shape)) + else: + this_img = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape) + out_array[ + tile_row * (H+Hs):tile_row*(H+Hs)+H, + tile_col * (W+Ws):tile_col*(W+Ws)+W + ] \ + = this_img * (255 if output_pixel_vals else 1) + return out_array + +