view pylearn/datasets/nade.py @ 1482:be4a49a65333

modified Nade dataset to use new config.get_filepath_in_roots mechanism
author gdesjardins
date Tue, 05 Jul 2011 10:56:40 -0400
parents b24ed2aa077e
children f7b348e6a98e
line wrap: on
line source

import os
import numpy

from pylearn.io.pmat import PMat
from pylearn.datasets.config import data_root # config
from pylearn.datasets.dataset import Dataset
import config

# Dataset names accepted by load_dataset; each one is looked up as the
# sub-directory 'larocheh/<name>' through config.get_filepath_in_roots.
_NADE_DATASETS = ('adult', 'binarized_mnist', 'mnist', 'connect4', 'dna',
                  'mushrooms', 'nips', 'ocr_letters', 'rcv1', 'web')


def _load_split(path, split):
    """Load one split ('train', 'valid' or 'test') stored in directory `path`.

    Reads `<split>_data.npy` (required) and `<split>_labels.npy` (optional)
    and wraps them in a Dataset.Obj with attributes `x` and `y`. `y` is None
    when the labels file does not exist.
    """
    x = numpy.load(os.path.join(path, split + '_data.npy'))
    y_fname = os.path.join(path, split + '_labels.npy')
    # Labels are optional: only load them when the file is present.
    y = numpy.load(y_fname) if os.path.exists(y_fname) else None
    return Dataset.Obj(x=x, y=y)


def load_dataset(name=None):
    """
    Various datasets which were used in the following paper.
    The Neural Autoregressive Distribution Estimator
    Hugo Larochelle and Iain Murray, AISTATS 2011

    :param name: string specifying which dataset to load; one of 'adult',
                 'binarized_mnist', 'mnist', 'connect4', 'dna', 'mushrooms',
                 'nips', 'ocr_letters', 'rcv1', 'web'
    :return: Dataset object
    dataset.train.x: matrix of training data of shape (num_examples, ndim)
    dataset.train.y: vector of training labels of length num_examples. Labels are
                     integer valued and represent the class it belongs to.
                     None if the dataset has no train_labels.npy file.
    dataset.valid.x: idem for validation data
    dataset.valid.y: idem for validation data
    dataset.test.x: idem for test data
    dataset.test.y: idem for test data
    :raises ValueError: if `name` is not one of the datasets listed above.

    WARNING: class labels are integer-valued instead of 1-of-n encoding !
    """
    # Explicit check instead of `assert`, so validation survives `python -O`.
    if name not in _NADE_DATASETS:
        raise ValueError('unknown NADE dataset %r; expected one of: %s'
                         % (name, ', '.join(_NADE_DATASETS)))
    rval = Dataset()

    # dataset lookup through $PYLEARN_DATA_ROOT
    path = config.get_filepath_in_roots(os.path.join('larocheh', name))

    rval.train = _load_split(path, 'train')
    rval.valid = _load_split(path, 'valid')
    rval.test = _load_split(path, 'test')

    return rval