view pylearn/datasets/miniblocks.py @ 1379:d90971353e22

Remove reference to fringant2, which no longer exists.
author Frederic Bastien <nouiz@nouiz.org>
date Thu, 02 Dec 2010 12:46:50 -0500
parents d15e4f803622
children
line wrap: on
line source

# Interface to miniblocks dataset.

import herding, numpy
import herding.dataset

from pylearn.datasets import Dataset

def _reweight_samples(inputs, targets, reweight):
    """Duplicate each row of `inputs` / `targets` a random number of times.

    Each sample is repeated k times, with k drawn uniformly from
    {1, 2, ..., reweight} using a fixed seed (1827), so the result is
    deterministic. Weights are then decremented at random (never below 1)
    until the total number of rows is a multiple of the original number of
    rows.

    :param inputs: array whose first axis indexes samples.
    :param targets: array with the same number of rows as `inputs`.
    :param reweight: maximum number of copies of any single sample (>= 1).
    :returns: ``(new_inputs, new_targets)``, both converted to float64, with
        the copies of each sample stored contiguously in the original order.
    """
    n_samples = len(inputs)
    rng = numpy.random.RandomState(1827)
    weights = rng.randint(1, reweight + 1, size = n_samples)
    total = numpy.sum(weights)
    # Prune at random until `total` is a multiple of `n_samples`. This always
    # terminates: every weight is >= 1, so a non-zero remainder implies
    # total > n_samples, i.e. at least one weight is still > 1. The
    # `n_samples` guard avoids a modulo by zero on an empty dataset.
    while n_samples and total % n_samples > 0:
        victim = rng.randint(n_samples)
        if weights[victim] > 1:
            weights[victim] -= 1
            total -= 1
    assert total == numpy.sum(weights)
    # numpy.repeat along axis 0 keeps the copies of each row contiguous, in
    # the original row order; astype(float) yields float64 arrays, matching
    # the zero-initialised float buffers this data used to be copied into.
    new_inputs = numpy.repeat(inputs, weights, axis = 0).astype(float)
    new_targets = numpy.repeat(targets, weights, axis = 0).astype(float)
    return new_inputs, new_targets


def miniblocks(reweight=None, use_inverse=False):
    """Load the 4x4 miniblocks dataset as a pylearn `Dataset`.

    :param reweight: if not None, an integer N >= 1. Every sample is then
        duplicated k times with k drawn uniformly from {1, 2, ..., N}, and
        the copy counts are adjusted so that the final dataset size is a
        multiple of the original size (see `_reweight_samples`).
    :param use_inverse: currently unused; kept for backward compatibility.
    :returns: a `Dataset` whose `train` and `test` attributes both hold the
        same data, `Dataset.Obj(x=inputs, y=targets)` (there is no held-out
        split), and whose `img_shape` is (4, 4). Targets are all zeros.
    :raises TypeError: if `reweight` is neither None nor an int.
    :raises ValueError: if `reweight` is an int smaller than 1.
    """
    # Validate before doing any (potentially slow) data loading. Explicit
    # raises are used instead of `assert`, which is stripped under -O.
    if reweight is not None:
        if not isinstance(reweight, int):
            raise TypeError('reweight must be an int or None, got %r'
                            % (reweight, ))
        if reweight < 1:
            raise ValueError('reweight must be >= 1, got %s' % (reweight, ))

    # NOTE(review): presumably batchsize=-1 makes the iterator yield the
    # whole dataset as a single batch; only the first batch is used here.
    data = herding.dataset.Miniblocks(4, batchsize = -1, forever = False,
            zeroone = True)
    inputs, _ = iter(data).next()

    # The target returned by herding seems to be a dummy target, so it is
    # replaced by a single all-zero column.
    targets = numpy.zeros((len(inputs), 1))

    if reweight is not None:
        inputs, targets = _reweight_samples(inputs, targets, reweight)
        print('Dataset size after reweighting: %s' % (inputs.shape, ))

    dataset = Dataset()
    dataset.train = Dataset.Obj(x = inputs, y = targets)
    dataset.test = Dataset.Obj(x = inputs, y = targets)
    dataset.img_shape = (4, 4)
    return dataset