Mercurial > pylearn
view pylearn/datasets/testDataset.py @ 1205:5525cf3faaa2
requirements: Question about the serialization requirement
author | Olivier Delalleau <delallea@iro> |
---|---|
date | Tue, 21 Sep 2010 10:48:01 -0400 |
parents | 16f91ca016b1 |
children |
line wrap: on
line source
"""Synthetic "DEBUG" dataset for exercising dataset-consuming code.

Despite the factory's historical name (``mnist_factory``), this module does
not load MNIST: it fabricates a small dataset with configurable split sizes,
number of classes and per-example shape, exposed through
``dataset_factory('DEBUG')``.
"""
from __future__ import absolute_import
from __future__ import print_function

import os
import numpy

from ..io.amat import AMat
from .config import data_root
from .dataset import dataset_factory, Dataset

# Kinds of input values the factory can generate (see ``valtype`` below).
VALSEQ, VALRAND = range(2)


@dataset_factory('DEBUG')
def mnist_factory(variant='', ntrain=10, nvalid=10, ntest=10,
                  nclass=2, ndim=1, dshape=None, valtype=VALSEQ):
    """Build a small synthetic dataset with train/valid/test splits.

    :param variant: accepted for interface compatibility; unused here.
    :param ntrain: number of training examples.
    :param nvalid: number of validation examples.
    :param ntest: number of test examples.
    :param nclass: number of classes; labels are drawn uniformly from
        ``{0, ..., nclass - 1}``.
    :param ndim: number of input dimensions, used only when ``dshape`` is
        None (the shape then defaults to ``[5] * ndim``).
    :param dshape: shape of a single example; each example is stored
        flattened into ``prod(dshape)`` values.
    :param valtype: ``VALSEQ`` fills the inputs with consecutive integers
        ``0, 1, 2, ...`` laid out row by row across examples; any other value
        fills them with uniform random floats in [0, 1).
    :returns: a ``Dataset`` with attributes ``n_classes``, ``img_shape``,
        ``train``, ``valid`` and ``test``; each split is a ``Dataset.Obj``
        whose ``x`` has shape ``(n, prod(dshape))`` and whose ``y`` holds
        ``n`` float labels.
    """
    if dshape is None:
        dshape = [5] * ndim

    rval = Dataset()
    rval.n_classes = nclass
    rval.img_shape = dshape

    # numpy.prod([]) is the float 1.0, which reshape() rejects; force an int.
    dsize = int(numpy.prod(dshape))

    print(ntrain, nvalid, ntest, nclass, dshape, valtype)

    ntot = ntrain + nvalid + ntest
    if valtype == VALSEQ:
        xdata = numpy.arange(ntot * dsize).reshape((ntot, dsize))
    else:
        xdata = numpy.random.random((ntot, dsize))

    # floor(u * nclass) with u ~ U[0, 1) gives labels uniform over
    # {0, ..., nclass - 1}. For nclass == 2 this matches the former
    # round(u) draw; labels stay floats for compatibility with callers.
    ydata = numpy.floor(numpy.random.random(ntot) * nclass)

    # Split boundaries: [0, ntrain) is train, [ntrain, ntrain + nvalid) is
    # valid, and [ntrain + nvalid, ntot) is test.
    valid_start = ntrain
    test_start = ntrain + nvalid
    rval.train = Dataset.Obj(x=xdata[:valid_start], y=ydata[:valid_start])
    rval.valid = Dataset.Obj(x=xdata[valid_start:test_start],
                             y=ydata[valid_start:test_start])
    rval.test = Dataset.Obj(x=xdata[test_start:ntot],
                            y=ydata[test_start:ntot])

    return rval