view datasets/shapeset1.py @ 513:6103dc5d2a0d

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 30 Oct 2008 19:39:26 -0400
parents 6fe692b93b69
children
line wrap: on
line source

"""
Routines to load/access Shapeset1
"""

from __future__ import absolute_import

import os
import numpy

from ..amat import AMat
from .config import data_root

def _head(path, n):
    """Read the first `n` examples of an AMat file.

    Parameters
    ----------
    path : str
        Path of the ``.amat`` file to read.
    n : int
        Number of leading examples to read (passed to ``AMat`` as ``head``).

    Returns
    -------
    (x, y)
        ``x`` is ``dat.input`` exactly as loaded by ``AMat``; ``y`` is
        ``dat.target`` converted to a 1-D ``int64`` numpy array of length ``n``.

    Raises
    ------
    Exception
        If the loaded inputs or targets do not have exactly ``n`` rows
        (e.g. the file holds fewer than ``n`` examples).
    """
    dat = AMat(path=path, head=n)

    # Explicit check instead of `assert`: assertions are removed under
    # `python -O`, which would let a short read through silently.
    if dat.input.shape[0] != n or dat.target.shape[0] != n:
        raise Exception("failed to read %i lines from file %s" % (n, path))

    # The reshape only succeeds when there is exactly one target value per
    # row; it flattens the targets into a 1-D vector of integer labels.
    labels = numpy.asarray(dat.target, dtype='int64').reshape(dat.target.shape[0])
    return dat.input, labels


def head_train(n=10000):
    """Load the first `n` examples of the Shapeset1 training split.

    Returns a pair (x, y).  x has `n` rows of 1024 columns; each row holds the
    32x32 grey-scale pixels of one image in raster order.  y is a 1-D int64
    vector of `n` labels, where y[i] is the label of row i of x.

    Raises Exception if the training file does not yield exactly `n` rows
    (the file is named for 10000 examples).
    """
    split_file = 'shapeset1_1cspo_2_3.10000.train.shape.amat'
    amat_path = os.path.join(data_root(), 'shapeset1', split_file)
    return _head(amat_path, n)

def head_valid(n=5000):
    """Load the first `n` examples of the Shapeset1 validation split.

    Returns a pair (x, y).  x has `n` rows of 1024 columns; each row holds the
    32x32 grey-scale pixels of one image in raster order.  y is a 1-D int64
    vector of `n` labels, where y[i] is the label of row i of x.

    Raises Exception if the validation file does not yield exactly `n` rows
    (the file is named for 5000 examples).
    """
    split_file = 'shapeset1_1cspo_2_3.5000.valid.shape.amat'
    amat_path = os.path.join(data_root(), 'shapeset1', split_file)
    return _head(amat_path, n)

def head_test(n=5000):
    """Load the first `n` examples of the Shapeset1 test split.

    Returns a pair (x, y).  x has `n` rows of 1024 columns; each row holds the
    32x32 grey-scale pixels of one image in raster order.  y is a 1-D int64
    vector of `n` labels, where y[i] is the label of row i of x.

    Raises Exception if the test file does not yield exactly `n` rows
    (the file is named for 5000 examples).
    """
    split_file = 'shapeset1_1cspo_2_3.5000.test.shape.amat'
    amat_path = os.path.join(data_root(), 'shapeset1', split_file)
    return _head(amat_path, n)

def train_valid_test(ntrain=10000, nvalid=5000, ntest=5000):
    """Load the leading examples of all three Shapeset1 splits.

    Returns a 3-tuple (train, valid, test).  Each element is the (x, y) pair
    produced by head_train(n=ntrain), head_valid(n=nvalid) and
    head_test(n=ntest) respectively; the splits are loaded in that order.
    """
    train = head_train(n=ntrain)
    valid = head_valid(n=nvalid)
    test = head_test(n=ntest)
    return train, valid, test