view pylearn/datasets/flickr.py @ 1479:1b69d435f09f

fix error string.
author Frederic Bastien <nouiz@nouiz.org>
date Wed, 25 May 2011 09:26:47 -0400
parents 4a7d413c3425
children
line wrap: on
line source

"""
Routines to load variations on the Flickr image dataset.
"""
from __future__ import absolute_import

import os
import numpy

from ..io import filetensor
from .config import data_root
from .dataset import Dataset


# File names of the three splits of the 10-class Flickr problem.  They are
# joined onto the dataset folder by the loaders in this module.
path_train_10class = 'flickr_10classes_train.ft'
path_valid_10class = 'flickr_10classes_valid.ft'
path_test_10class = 'flickr_10classes_test.ft'

def _read_flickr_split(root, filename):
    """Read one filetensor file located in `root` and close it afterwards."""
    # NOTE(review): .ft files are presumably binary, hence 'rb'; this avoids
    # newline translation on platforms where text mode differs.
    with open(os.path.join(root, filename), 'rb') as f:
        return filetensor.read(f)


def basic_10class(folder = None):
    """Return the basic flickr image classification problem.

    The images are 75x75 (each row holds 75*75 pixel values followed by one
    label column; this is asserted below).  The original documentation states
    that there are 7500 training examples; that count is not checked here.

    :param folder: directory containing the three ``flickr_10classes_*.ft``
        files.  When None, ``<data_root()>/flickr`` is used.

    :returns: a `Dataset` with attributes ``train``, ``valid`` and ``test``
        (each a `Dataset.Obj` with ``x`` = all columns but the last and
        ``y`` = the last column converted to int64), plus ``n_classes`` (10)
        and ``img_shape`` ((75, 75)).

    :raises AssertionError: if a split does not have 75*75 + 1 columns.
    """
    root = os.path.join(data_root(), 'flickr') if folder is None else folder
    train = _read_flickr_split(root, path_train_10class)
    valid = _read_flickr_split(root, path_valid_10class)
    test = _read_flickr_split(root, path_test_10class)

    # One column per pixel plus a trailing label column.
    n_columns = 75*75 + 1
    assert train.shape[1] == n_columns
    assert valid.shape[1] == n_columns
    assert test.shape[1] == n_columns

    rval = Dataset()

    rval.train = Dataset.Obj(
            x=train[:, 0:-1],
            y=numpy.asarray(train[:, -1], dtype='int64'))
    rval.valid = Dataset.Obj(
            x=valid[:, 0:-1],
            y=numpy.asarray(valid[:, -1], dtype='int64'))
    rval.test = Dataset.Obj(
            x=test[:, 0:-1],
            y=numpy.asarray(test[:, -1], dtype='int64'))

    rval.n_classes = 10
    rval.img_shape = (75,75)

    return rval

def translations_10class():
    """Placeholder, presumably for a translated variant of the 10-class task.

    Nothing is implemented yet: every call raises NotImplementedError.
    """
    pending = 'TODO'
    raise NotImplementedError(pending)


def render_a_few_images(n=10, prefix='flickr_img', suffix='png'):
    """Save the first `n` validation images as 8-bit grayscale image files.

    The validation split is read from ``<data_root()>/flickr``.  Image ``i``
    is written to ``prefix + str(i) + '.' + suffix`` (relative to the current
    working directory unless `prefix` contains a path).

    :param n: number of images to render, taken from the start of the split.
    :param prefix: leading part of every output file name.
    :param suffix: file extension passed on to PIL via the file name.

    :raises AssertionError: if the split is not 1000 x (75*75 + 1), or if a
        pixel value lies outside [0, 1].
    """
    #TODO: document this and move it to a more common
    #      place where other datasets can use it
    from PIL import Image
    root = os.path.join(data_root(), 'flickr')
    # NOTE(review): .ft files are presumably binary, hence 'rb'; the `with`
    # block closes the handle once the data has been read.
    with open(os.path.join(root, path_valid_10class), 'rb') as f:
        valid = filetensor.read(f)
    assert valid.shape == (1000,75*75+1)
    for i in range(n):
        # Drop the trailing label column and view the row as a 75x75 image.
        pixelarray = valid[i,0:-1].reshape((75,75)).T
        assert numpy.all(pixelarray >= 0)
        assert numpy.all(pixelarray <= 1)

        # The transpose above is only a strided view; force a C-contiguous
        # uint8 copy so that the raw buffer handed to PIL really is laid out
        # row by row in the transposed orientation.
        pixel_uint8 = numpy.ascontiguousarray(pixelarray * 255.0, dtype='uint8')
        im = Image.frombuffer('L', pixel_uint8.shape, pixel_uint8.data, 'raw', 'L', 0, 1)
        im.save(prefix + str(i) + '.' + suffix)