Mercurial > pylearn
view pylearn/datasets/flickr.py @ 1198:1387771296a8
v2planning adding plugin_JB
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 20 Sep 2010 02:34:23 -0400 |
parents | 4a7d413c3425 |
children |
line wrap: on
line source
""" Routines to load variations on the Flickr image dataset. """ from __future__ import absolute_import import os import numpy from ..io import filetensor from .config import data_root from .dataset import Dataset path_test_10class ='flickr_10classes_test.ft' path_train_10class = 'flickr_10classes_train.ft' path_valid_10class = 'flickr_10classes_valid.ft' def basic_10class(folder = None): """Return the basic flickr image classification problem. The images are 75x75, and there are 7500 training examples. """ root = os.path.join(data_root(), 'flickr') if folder is None else folder train = filetensor.read(open(os.path.join(root, path_train_10class))) valid = filetensor.read(open(os.path.join(root, path_valid_10class))) test = filetensor.read(open(os.path.join(root, path_test_10class))) assert train.shape[1] == 75*75 +1 assert valid.shape[1] == 75*75 +1 assert test.shape[1] == 75*75 +1 rval = Dataset() rval.train = Dataset.Obj( x=train[:, 0:-1], y=numpy.asarray(train[:, -1], dtype='int64')) rval.valid = Dataset.Obj( x=valid[:, 0:-1], y=numpy.asarray(valid[:, -1], dtype='int64')) rval.test = Dataset.Obj( x=test[:, 0:-1], y=numpy.asarray(test[:, -1], dtype='int64')) rval.n_classes = 10 rval.img_shape = (75,75) return rval def translations_10class(): raise NotImplementedError('TODO') def render_a_few_images(n=10, prefix='flickr_img', suffix='png'): #TODO: document this and move it to a more common # place where other datasets can use it from PIL import Image root = os.path.join(data_root(), 'flickr') valid = filetensor.read(open(os.path.join(root, path_valid_10class))) assert valid.shape == (1000,75*75+1) for i in xrange(n): pixelarray = valid[i,0:-1].reshape((75,75)).T assert numpy.all(pixelarray >= 0) assert numpy.all(pixelarray <= 1) pixel_uint8 = numpy.asarray( pixelarray * 255.0, dtype='uint8') im = Image.frombuffer('L', pixel_uint8.shape, pixel_uint8.data, 'raw', 'L', 0, 1) im.save(prefix + str(i) + '.' + suffix)