view pylearn/datasets/testDataset.py @ 1479:1b69d435f09f

fix error string.
author Frederic Bastien <nouiz@nouiz.org>
date Wed, 25 May 2011 09:26:47 -0400
parents 16f91ca016b1
children
line wrap: on
line source

"""
Factory for a small synthetic 'DEBUG' dataset (no MNIST data is loaded).
"""
from __future__ import absolute_import

import logging
import os

import numpy

from ..io.amat import AMat
from .config import data_root
from .dataset import dataset_factory, Dataset

# Choices for the `valtype` argument of mnist_factory:
#   VALSEQ  -> inputs are the sequence 0, 1, 2, ... (numpy.arange)
#   VALRAND -> inputs are uniform random numbers
VALSEQ = 0
VALRAND = 1

@dataset_factory('DEBUG')
def mnist_factory(variant='', ntrain=10, nvalid=10, ntest=10,
                  nclass=2, ndim=1, dshape=None, valtype=VALSEQ):
    """Build a small synthetic dataset, registered under the name 'DEBUG'.

    Despite the function name, no MNIST data is read: all inputs and
    labels are generated in memory.

    Parameters
    ----------
    variant : str
        Not used by this factory (presumably accepted so the signature
        matches other registered factories -- confirm against callers).
    ntrain, nvalid, ntest : int
        Number of examples in the train, valid and test splits.
    nclass : int
        Number of classes. Stored as ``n_classes``; labels are drawn
        uniformly from ``0 .. nclass-1`` and returned as floats.
    ndim : int
        Number of dimensions of the default example shape ``[5] * ndim``;
        ignored when `dshape` is given.
    dshape : sequence of int, optional
        Shape of a single example. Defaults to ``[5] * ndim``.
    valtype : int
        ``VALSEQ`` fills the inputs with 0, 1, 2, ...; any other value
        (e.g. ``VALRAND``) fills them with uniform random numbers in [0, 1).

    Returns
    -------
    Dataset
        With attributes ``n_classes``, ``img_shape`` and ``train``,
        ``valid``, ``test``. Each split is a ``Dataset.Obj`` whose ``x``
        has shape ``(n, prod(dshape))`` (examples flattened) and whose
        ``y`` has shape ``(n,)``.
    """
    if dshape is None:
        dshape = [5] * ndim

    rval = Dataset()
    rval.n_classes = nclass
    rval.img_shape = dshape

    # int(): numpy.prod of an empty shape is the float 1.0, which
    # reshape() rejects.
    dsize = int(numpy.prod(dshape))

    logging.getLogger(__name__).debug(
        'DEBUG dataset: ntrain=%s nvalid=%s ntest=%s nclass=%s '
        'dshape=%s valtype=%s',
        ntrain, nvalid, ntest, nclass, dshape, valtype)

    ntot = ntrain + nvalid + ntest
    # Compare by value, not identity: `is` only worked because CPython
    # caches small ints, and fails for numpy ints or Py2 longs.
    if valtype == VALSEQ:
        # Deterministic, distinct values for every input element.
        xdata = numpy.arange(ntot * dsize).reshape((ntot, dsize))
    else:
        xdata = numpy.random.random((ntot, dsize))

    # floor(U[0,1) * nclass) is uniform over 0..nclass-1 and stays a float
    # array; for nclass=2 it matches the previous round(U[0,1)).
    ydata = numpy.floor(numpy.random.random(ntot) * nclass)

    # Splits are contiguous: [train | valid | test].
    train_end = ntrain
    valid_end = ntrain + nvalid
    rval.train = Dataset.Obj(x=xdata[:train_end], y=ydata[:train_end])
    rval.valid = Dataset.Obj(x=xdata[train_end:valid_end],
                             y=ydata[train_end:valid_end])
    rval.test = Dataset.Obj(x=xdata[valid_end:ntot],
                            y=ydata[valid_end:ntot])

    return rval