Mercurial > pylearn
view pylearn/dataset_ops/protocol.py @ 1223:621e03253f0c
mini-bug in taglist
author | boulanni <nicolas_boulanger@hotmail.com> |
---|---|
date | Wed, 22 Sep 2010 15:08:39 -0400 |
parents | 5cbd235f8fb8 |
children | f49801e39fe3 |
line wrap: on
line source
"""Convenience base classes to help with writing Dataset ops

"""

__docformat__ = "restructuredtext_en"

import theano


class Dataset(theano.Op):
    """
    The basic dataset interface is an expression that maps an integer to a
    dataset element.

    There is also a minibatch option, in which the expression maps an array
    of integers to a list or array of dataset elements.
    """

    def __init__(self, single_type, batch_type):
        """
        :param single_type: callable returning the output variable used when
            the index is a scalar (one example at a time).
        :param batch_type: callable returning the output variable used when
            the index is a vector (many examples at a time).
        """
        self.single_type = single_type
        self.batch_type = batch_type

    def make_node(self, idx):
        """Build the Apply node mapping the index `idx` to dataset element(s).

        :param idx: anything accepted by `theano.tensor.as_tensor_variable`;
            its dtype must start with 'int' and it must have 0 or 1
            dimensions.
        :raises TypeError: if the dtype does not start with 'int' or the
            index has more than one dimension.
        """
        _idx = theano.tensor.as_tensor_variable(idx)
        if not _idx.dtype.startswith('int'):
            raise TypeError(
                'Dataset index must have an int dtype, got %s' % _idx.dtype)
        if _idx.ndim == 0:
            # one example at a time
            otype = self.single_type
        elif _idx.ndim == 1:
            # many examples at a time
            otype = self.batch_type
        else:
            raise TypeError(idx)
        return theano.Apply(self, [_idx], [otype()])

    def __eq__(self, other):
        # Exact type match is intentional: two ops are interchangeable only
        # when they are of the same class and produce the same output types.
        return type(self) == type(other) \
            and self.single_type == other.single_type \
            and self.batch_type == other.batch_type

    def __hash__(self):
        # Combines exactly the fields compared in __eq__.
        return (hash(type(self))
                ^ hash(self.single_type)
                ^ hash(self.batch_type))

    def __str__(self):
        return "%s{%s,%s}" % (self.__class__.__name__,
                              self.single_type, self.batch_type)

    def grad(self, inputs, g_outputs):
        """Return one `None` per input: no gradient is defined w.r.t. the
        index."""
        return [None for i in inputs]


class TensorDataset(Dataset):
    """A convenient base class for Datasets whose elements all have the same
    TensorType.
    """

    def __init__(self, dtype, single_broadcastable, single_shape=None,
                 batch_size=None):
        """
        :param dtype: dtype of every dataset element.
        :param single_broadcastable: broadcastable pattern of one element
            (any iterable; it is converted to a tuple).
        :param single_shape: currently unused (the shape hint below is
            commented out).
        :param batch_size: currently unused (the shape hint below is
            commented out).
        """
        single_broadcastable = tuple(single_broadcastable)
        single_type = theano.tensor.Tensor(
            broadcastable=single_broadcastable,
            dtype=dtype)
        # shape=single_shape)
        # A batch adds one leading, non-broadcastable example dimension.
        batch_type = theano.tensor.Tensor(
            broadcastable=(False,) + single_type.broadcastable,
            dtype=dtype)
        # shape=(batch_size,)+single_type.shape)
        super(TensorDataset, self).__init__(single_type, batch_type)


class TensorFnDataset(TensorDataset):
    """A good base class for TensorDatasets that can be read from disk and
    cached in memory

    The dataset is accessed via a function call to make this Op pickle-able.
    If the function is a normal module-level function, then this Op will be
    picklable. If the dataset were a property of this Op, then pickling the
    Op would require pickling the entire dataset.
    """

    def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None):
        """
        :type fn: callable or (callable, args) tuple [MUST BE PICKLABLE!]
        :param fn: function that returns the dataset as a tensor. Leading
            index is the example index, others are considered part of each
            example.
        """
        super(TensorFnDataset, self).__init__(dtype, bcast, single_shape,
                                              batch_size)
        try:
            self.fn, self.fn_args = fn
        except (TypeError, ValueError):
            # `fn` is a bare callable (not iterable -> TypeError) or not a
            # 2-sequence (ValueError): treat it as a function with no args.
            self.fn, self.fn_args = fn, ()

    def __eq__(self, other):
        return super(TensorFnDataset, self).__eq__(other) \
            and self.fn == other.fn \
            and self.fn_args == other.fn_args

    def __hash__(self):
        # Note: requires `fn_args` to be hashable (e.g. a tuple).
        return (super(TensorFnDataset, self).__hash__()
                ^ hash(self.fn)
                ^ hash(self.fn_args))

    def __str__(self):
        # Not every callable has a __name__ (e.g. functools.partial objects
        # or callable instances); fall back to the callable itself.
        fn_label = getattr(self.fn, '__name__', self.fn)
        return "%s{%s,%s}" % (self.__class__.__name__, fn_label,
                              self.fn_args)

    def perform(self, node, inputs, output_storage):
        """Fetch the element(s) selected by the index into the output cell.

        :param node: the Apply node being evaluated (unused here).
        :param inputs: one-element sequence holding the index array.
        :param output_storage: one-element sequence holding the output cell;
            the result is written to its slot 0.
        """
        # Unpack in the body rather than in the signature: tuple parameters
        # are a syntax error on Python 3.
        idx, = inputs
        z, = output_storage
        # `fn` is called on every evaluation; any caching is up to `fn`.
        x = self.fn(*self.fn_args)
        if idx.ndim == 0:
            # 0-d index array -> plain int, so a single example is returned
            z[0] = x[int(idx)]
        else:
            z[0] = x[idx]