pylearn/dataset_ops/protocol.py @ 1470:94268a161925
memmap support in Dataset op
James Bergstra <bergstrj@iro.umontreal.ca>, Wed, 18 May 2011 10:50:21 -0400

"""Convenience base classes to help with writing Dataset ops

"""

__docformat__ = "restructuredtext_en"
import numpy
import theano

class Dataset(theano.Op):
    """ 
    The basic dataset interface is an expression that maps an integer to a dataset element.

    There is also a minibatch option, in which the expression maps an array of integers to a
    list or array of dataset elements.
    """
    def __init__(self, single_type, batch_type):
        self.single_type = single_type
        self.batch_type = batch_type

    def make_node(self, idx):
        _idx = theano.tensor.as_tensor_variable(idx)
        if not _idx.dtype.startswith('int'):
            raise TypeError('idx must have an integer dtype', idx)
        if _idx.ndim == 0: # one example at a time
            otype = self.single_type
        elif _idx.ndim == 1: # many examples at a time
            otype = self.batch_type
        else:
            raise TypeError('idx must be a 0-d or 1-d variable', idx)
        return theano.Apply(self, [_idx], [otype()])

    def __eq__(self, other):
        return type(self) == type(other) \
                and self.single_type == other.single_type \
                and self.batch_type == other.batch_type

    def __hash__(self):
        return hash(type(self)) ^ hash(self.single_type) ^ hash(self.batch_type)

    def __str__(self):
        return "%s{%s,%s}" % (self.__class__.__name__, self.single_type, self.batch_type)

    def grad(self, inputs, g_outputs):
        # datasets are indexed by integers, so there is no gradient to propagate
        return [None for i in inputs]
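
# Example (illustrative sketch, not part of the original module): how the
# scalar/vector index dispatch in make_node plays out.  The function name and
# the choice of dvector/dmatrix are hypothetical stand-ins for whatever
# single/batch types a subclass would choose; Dataset itself defines no
# perform, so this only builds graph variables without evaluating them.
def _example_dataset_types():
    d = Dataset(single_type=theano.tensor.dvector,
                batch_type=theano.tensor.dmatrix)
    one = d(theano.tensor.lscalar())  # 0-d index -> a single dvector example
    many = d(theano.tensor.lvector()) # 1-d index -> a dmatrix minibatch
    assert one.type == theano.tensor.dvector
    assert many.type == theano.tensor.dmatrix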


class TensorDataset(Dataset):
    """A convenient base class for Datasets whose elements all have the same TensorType.
    """
    def __init__(self, dtype, single_broadcastable, single_shape=None, batch_size=None):
        single_broadcastable = tuple(single_broadcastable)
        self.single_shape = single_shape
        self.batch_size = batch_size
        single_type = theano.tensor.Tensor(
                broadcastable=single_broadcastable,
                dtype=dtype)
        batch_type = theano.tensor.Tensor(
                broadcastable=(False,)+single_type.broadcastable,
                dtype=dtype)
        super(TensorDataset, self).__init__(single_type, batch_type)
    def __eq__(self, other):
        return (super(TensorDataset, self).__eq__(other)
                and self.single_shape == other.single_shape
                and self.batch_size == other.batch_size)
    def __hash__(self):
        return (super(TensorDataset, self).__hash__()
                ^ hash(self.single_shape)
                ^ hash(self.batch_size))
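
# Example (illustrative sketch, not part of the original module): a dataset of
# 28x28 images (hypothetical shape and batch size) stores one example as a
# matrix and a minibatch as a 3-tensor; the batch type simply gains one
# non-broadcastable leading dimension.
def _example_tensor_dataset():
    d = TensorDataset('float64', single_broadcastable=(False, False),
            single_shape=(28, 28), batch_size=100)
    assert d.single_type.ndim == 2
    assert d.batch_type.ndim == 3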

class TensorFnDataset(TensorDataset):
    """A base class for TensorDatasets that are backed by indexable objects,
    e.g. numpy ndarrays and memmaps.

    This Op looks up the dataset by a function call, rather than by storing it
    as a member variable.  This is done to make the graph serializable without
    having to save the dataset itself, which is typically large.

    This Op is picklable if (and only if) the function that accesses the dataset
    can be serialized.
    """
    def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None):
        """
        :type fn: callable or (callable, args) tuple [MUST BE PICKLABLE!]
        :param fn: function that returns the dataset as an ndarray-like object.

        :type bcast: tuple of bool
        :param bcast: the broadcastable pattern of the return value when this
            op is indexed by a scalar (the one-example case).  A (False,) is
            prepended to this pattern when the op is indexed by a vector.
        """
        super(TensorFnDataset, self).__init__(dtype, bcast, single_shape, batch_size)
        try:
            # accept either a bare callable or a (callable, args) tuple
            self.fn, self.fn_args = fn
        except TypeError:
            self.fn, self.fn_args = fn, ()
    def __getstate__(self):
        # do not pickle the cached dataset (x_); it is typically large and can
        # be re-fetched through self.fn after unpickling
        rval = dict(self.__dict__)
        if 'x_' in rval:
            del rval['x_']
        return rval

    def __eq__(self, other):
        return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn \
                and self.fn_args == other.fn_args

    def __hash__(self):
        return (super(TensorFnDataset, self).__hash__()
                ^ hash(self.fn)
                ^ hash(self.fn_args))

    def __str__(self):
        try:
            return "%s{%s,%s}" % (self.__class__.__name__, self.fn.__name__,
                    self.fn_args)
        except AttributeError:
            return "%s{%s,%s}" % (self.__class__.__name__, self.fn, self.fn_args)

    def perform(self, node, inputs, output_storage):
        idx, = inputs
        z, = output_storage
        try:
            x = self.x_
        except AttributeError:
            # fetch the dataset lazily on first use and cache it
            x = self.x_ = self.fn(*self.fn_args)
        if idx.ndim == 0:
            z[0] = numpy.asarray(x[int(idx)]) # asarray is important for memmaps
        else:
            z[0] = numpy.asarray(x[idx]) # asarray is important for memmaps
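
# Example (illustrative sketch, not part of the original module): backing a
# TensorFnDataset with a module-level loader function.  _demo_load and the
# shapes below are hypothetical; a real loader would typically return a
# numpy.memmap.  Keeping the loader at module scope is what keeps the op
# picklable, since __getstate__ drops the cached array but not fn itself.
def _demo_load():
    return numpy.arange(20, dtype='float64').reshape(10, 2)

def _example_tensor_fn_dataset():
    op = TensorFnDataset('float64', bcast=(False,), fn=_demo_load,
            single_shape=(2,))
    idx = theano.tensor.lvector()
    f = theano.function([idx], op(idx))
    # pickling the op should work here, e.g. pickle.loads(pickle.dumps(op)),
    # because _demo_load can be found again by import
    return f([0, 3, 7])  # rows 0, 3 and 7 of the underlying array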