view pylearn/dataset_ops/protocol.py @ 1417:f49801e39fe3

TensorDataset - added single_shape and batch_size properties
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 04 Feb 2011 16:02:57 -0500
parents 5cbd235f8fb8
children 383d4c061546
line wrap: on
line source

"""Convenience base classes to help with writing Dataset ops

"""

__docformat__  = "restructuredtext_en"
import theano

class Dataset(theano.Op):
    """Base Op that maps integer indices to dataset elements.

    The basic dataset interface is an expression that maps an integer to a
    dataset element.

    There is also a minibatch option, in which the expression maps an array of
    integers to a list or array of dataset elements.

    Two instances compare (and hash) equal when they are of the same class and
    have equal `single_type` and `batch_type`.
    """

    def __init__(self, single_type, batch_type):
        """
        :param single_type: type object called with no arguments to create the
            output variable when the index is a scalar.

        :param batch_type: type object called with no arguments to create the
            output variable when the index is a vector.
        """
        self.single_type = single_type
        self.batch_type = batch_type

    def make_node(self, idx):
        """Build the Apply node that looks `idx` up in this dataset.

        :param idx: anything `theano.tensor.as_tensor_variable` accepts.

        :returns: a `theano.Apply` with the converted index as its only input
            and a single output of `single_type` (0-d index) or `batch_type`
            (1-d index).

        :raises TypeError: if the index dtype name does not start with 'int',
            or if the index has more than one dimension.
        """
        _idx = theano.tensor.as_tensor_variable(idx)
        # Prefix test on the dtype name: floats are rejected, and so are names
        # such as 'uint8' because they do not start with 'int'.
        if not _idx.dtype.startswith('int'):
            raise TypeError(
                'dataset index must have an integer dtype, got %s'
                % _idx.dtype, idx)
        if _idx.ndim == 0:  # one example at a time
            otype = self.single_type
        elif _idx.ndim == 1:  # many examples at a time
            otype = self.batch_type
        else:
            raise TypeError(
                'dataset index must be a scalar or a vector, got ndim=%s'
                % _idx.ndim, idx)
        return theano.Apply(self, [_idx], [otype()])

    def __eq__(self, other):
        """Return True when `other` has the same class and the same types."""
        return type(self) == type(other) \
                and self.single_type == other.single_type \
                and self.batch_type == other.batch_type

    def __ne__(self, other):
        """Return the negation of `==`.

        Python 2 does not derive `!=` from `__eq__`; without this method two
        equal instances would still compare as different with `!=`.
        """
        return not (self == other)

    def __hash__(self):
        """Hash consistent with `__eq__`: combines the class and both types."""
        return hash(type(self)) ^ hash(self.single_type) ^ hash(self.batch_type)

    def __str__(self):
        """Return ``ClassName{single_type,batch_type}``."""
        return "%s{%s,%s}" % (self.__class__.__name__, self.single_type, self.batch_type)

    def grad(self, inputs, g_outputs):
        """Return one `None` per input: no gradient is defined for the index."""
        return [None for _ in inputs]


class TensorDataset(Dataset):
    """A convenient base class for Datasets whose elements all have the same TensorType.

    The single-element type is built from `dtype` and `single_broadcastable`;
    the batch type is the same with one leading non-broadcastable dimension.
    `single_shape` and `batch_size` are stored as given and take part in
    equality and hashing.
    """

    def __init__(self, dtype, single_broadcastable, single_shape=None, batch_size=None):
        """
        :param dtype: dtype passed to both tensor types.

        :param single_broadcastable: iterable giving the broadcastable pattern
            of a single element (converted to a tuple).

        :param single_shape: stored unchanged on the instance; presumably the
            shape of one example, but this class never interprets it.

        :param batch_size: stored unchanged on the instance; presumably the
            number of examples per minibatch, but this class never interprets
            it.
        """
        single_broadcastable = tuple(single_broadcastable)
        self.single_shape = single_shape
        self.batch_size = batch_size
        single_type = theano.tensor.Tensor(
                broadcastable=single_broadcastable,
                dtype=dtype)
        # A batch stacks elements along a new leading axis, which is never
        # broadcastable.
        batch_type = theano.tensor.Tensor(
                broadcastable=(False,)+single_type.broadcastable,
                dtype=dtype)
        super(TensorDataset, self).__init__(single_type, batch_type)

    def __eq__(self, other):
        """Equal when the base class says so and shape/batch size also match."""
        # The base-class test runs first so that `other` is known to be of the
        # same class before its attributes are read.
        return (super(TensorDataset, self).__eq__(other)
                and self.single_shape == other.single_shape
                and self.batch_size == other.batch_size)

    def __hash__(self):
        """Hash consistent with `__eq__`, tolerating list-valued properties."""
        return (super(TensorDataset, self).__hash__()
                ^ self._hash_property(self.single_shape)
                ^ self._hash_property(self.batch_size))

    @staticmethod
    def _hash_property(value):
        """Hash `value`, falling back to `tuple(value)` when it is unhashable.

        A shape given as a list (e.g. ``[28, 28]``) would otherwise make the
        whole Op unhashable.  Equal lists still yield equal hashes, so the
        result stays consistent with `__eq__`.
        """
        try:
            return hash(value)
        except TypeError:
            return hash(tuple(value))

class TensorFnDataset(TensorDataset):
    """A good base class for TensorDatasets that can be read from disk and cached in memory

    The dataset is accessed via a function call to make this Op pickle-able.  If the function
    is a normal module-level function, then this Op will be picklable.  If the dataset were a
    property of this Op, then pickling the Op would require pickling the entire dataset.

    `fn` is called on every `perform`, so any caching has to be done by `fn`
    itself.
    """

    def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None):
        """
        :type fn: callable or (callable, args) tuple [MUST BE PICKLABLE!]

        :param fn: function that returns the dataset as a tensor. Leading index is the example
        index, others are considered part of each example.  When given as a
        pair, the second item is unpacked as the positional arguments of the
        call; it must be hashable because it takes part in `__hash__`.

        :param dtype: forwarded to `TensorDataset`.

        :param bcast: forwarded to `TensorDataset` as `single_broadcastable`.

        :param single_shape: forwarded to `TensorDataset`.

        :param batch_size: forwarded to `TensorDataset`.
        """
        super(TensorFnDataset, self).__init__(dtype, bcast, single_shape, batch_size)
        try:
            self.fn, self.fn_args = fn
        except (TypeError, ValueError):
            # Not a 2-item sequence (TypeError: not iterable; ValueError:
            # wrong length): treat `fn` as a bare callable without arguments.
            self.fn, self.fn_args = fn, ()

    def __eq__(self, other):
        """Equal when the base class says so and `fn` / `fn_args` also match."""
        return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn \
                and self.fn_args == other.fn_args

    def __hash__(self):
        """Hash consistent with `__eq__`: also mixes in `fn` and `fn_args`."""
        return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) ^ hash(self.fn_args)

    def __str__(self):
        """Return ``ClassName{fn,fn_args}``.

        The function's `__name__` is used when it has one; otherwise the
        callable itself is formatted (e.g. a `functools.partial` object).
        """
        fn_label = getattr(self.fn, '__name__', self.fn)
        return "%s{%s,%s}" % (self.__class__.__name__, fn_label, self.fn_args)

    def perform(self, node, inputs, output_storage):
        """Fetch the requested example(s) and store them in the output cell.

        :param node: the Apply node being evaluated (unused here).

        :param inputs: one-item sequence holding the index value.

        :param output_storage: one-item sequence holding the output cell, a
            container whose item 0 receives the result.
        """
        # Unpacked here rather than in the signature: tuple parameters are
        # Python 2-only syntax.  Positional callers are unaffected.
        idx, = inputs
        z, = output_storage
        x = self.fn(*self.fn_args)
        if idx.ndim == 0:
            # A 0-d index is converted to a plain int before indexing.
            z[0] = x[int(idx)]
        else:
            z[0] = x[idx]