view pylearn/datasets/flickr.py @ 601:fd95ff96dd47

updated flickr to row-major files
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 14 Jan 2009 17:00:57 -0500
parents e56303df3c77
children 28f7dc848efc
line wrap: on
line source

"""
Routines to load variations on the Flickr image dataset.
"""
from __future__ import absolute_import

import os
import numpy

from ..io import filetensor
from .config import data_root
from .dataset import Dataset


def _read_filetensor(path):
    """Read the whole filetensor file at `path` and return its contents.

    The file is opened in binary mode ('rb'). filetensor files hold raw
    numeric data, which text mode could corrupt on platforms that translate
    line endings. The handle is closed before returning, including when
    ``filetensor.read`` raises, so repeated loads do not leak descriptors.
    """
    with open(path, 'rb') as f:
        return filetensor.read(f)


def test_10class(path='flickr_10classes_test.ft'):
    """Load the test split of the 10-class Flickr problem.

    :param path: filetensor file to read. The default is a bare filename,
        so it is resolved relative to the current working directory.
    :returns: the contents of the file as returned by ``filetensor.read``.
    """
    #TODO: make default path relative to data_root()
    return _read_filetensor(path)


def train_10class(path='flickr_10classes_train.ft'):
    """Load the training split of the 10-class Flickr problem.

    :param path: filetensor file to read. The default is a bare filename,
        so it is resolved relative to the current working directory.
    :returns: the contents of the file as returned by ``filetensor.read``.
    """
    #TODO: make default path relative to data_root()
    return _read_filetensor(path)


def valid_10class(path='flickr_10classes_valid.ft'):
    """Load the validation split of the 10-class Flickr problem.

    :param path: filetensor file to read. The default is a bare filename,
        so it is resolved relative to the current working directory.
    :returns: the contents of the file as returned by ``filetensor.read``.
    """
    #TODO: make default path relative to data_root()
    return _read_filetensor(path)

def basic_10class():
    """Build the basic 10-class Flickr image classification problem.

    The train, valid and test tables are loaded (in that order). Each row is
    split into features (all columns but the last) and an int64 label (the
    last column). Each split is stored as a ``Dataset.Obj`` with ``x`` and
    ``y`` on the returned ``Dataset``, under the attribute of the same name.
    The result also records ``n_classes = 10`` and ``img_shape = (75, 75)``.

    NOTE(review): earlier documentation stated there are 7500 training
    examples; nothing here checks that count.
    """
    # Load every split up front, before the Dataset is created.
    loaded = [
        ('train', train_10class()),
        ('valid', valid_10class()),
        ('test', test_10class()),
    ]

    dataset = Dataset()
    for split_name, table in loaded:
        features = table[:, 0:-1]
        labels = numpy.asarray(table[:, -1], dtype='int64')
        setattr(dataset, split_name, Dataset.Obj(x=features, y=labels))

    dataset.n_classes = 10
    dataset.img_shape = (75,75)
    return dataset

def translations_10class():
    """Placeholder loader; not implemented yet.

    Judging by its name, this is presumably meant to provide a translated-image
    variant of the 10-class problem. For now every call raises
    NotImplementedError.
    """
    raise NotImplementedError('TODO')