Mercurial > pylearn
view onehotop.py.scalar @ 468:a07948f780b9
Moved embeddings out of sandbox
author | Joseph Turian <turian@iro.umontreal.ca> |
---|---|
date | Tue, 21 Oct 2008 16:24:44 -0400 |
parents | 18702ceb2096 |
children |
line wrap: on
line source
"""One-hot Op.

Defines L{OneHot}, an Op that turns a vector of integer indices into a
matrix whose rows are one-hot vectors, plus the ready-made instance
L{one_hot}.
"""
from theano.tensor import as_tensor, Tensor
from theano.scalar import as_scalar
from theano.gof import op
from theano.gof.graph import Apply
import numpy


class OneHot(op.Op):
    """
    Construct a matrix of one-hot row vectors from a vector of indices.

    Given an integer vector C{x} of length n and an integer scalar C{y},
    the output is an (n, y) float64 matrix that is zero everywhere except
    C{out[i, x[i]] == 1} for every row i.

    @todo: Use 'bool' as output dtype? Or, at least 'int64' ? Not float64!
    @todo: Use 'bool' as output dtype, not 'int64' ?
    @todo: Allow this to operate on column vectors (Tensor)
    @todo: What about operating on L{Scalar}s?
    """

    def make_node(self, x, y):
        """
        Build the Apply node for C{one_hot(x, y)}.

        @type x: Vector L{Tensor} of integers
        @param x: For each output row, the column index to set to one.
        @type y: Integer L{Scalar}
        @param y: The length (#columns) of the one-hot vectors.
        @return: An L{Apply} node whose single output is a float64 matrix
            L{Tensor} of one-hot rows.
        @precondition: 0 <= x < y for all entries of x
        @todo: Check that x and y are int types
        """
        x = as_tensor(x)
        y = as_scalar(y)
        # The output is always a non-broadcastable float64 matrix; see the
        # class-level @todo items about choosing a narrower dtype.
        outputs = [Tensor("float64", broadcastable=[False, False]).make_result()]
        return Apply(op=self, inputs=[x, y], outputs=outputs)

    def perform(self, node, inputs, output_storage):
        """
        Numerically compute the one-hot matrix.

        @param node: The L{Apply} node being evaluated (not used here).
        @param inputs: C{(x, y)}: a 1-d int64 numpy array of column indices
            and a C{numpy.int64} row length.
        @param output_storage: One-element sequence holding the mutable
            storage cell whose C{[0]} slot receives the result.
        @raise ValueError: if some entry of x is negative or not less than y.
        """
        # Unpack here rather than in the signature: tuple parameters are
        # Python-2-only syntax (removed by PEP 3113).
        x, y = inputs
        out, = output_storage
        assert x.dtype == "int64"
        assert isinstance(y, numpy.int64)
        assert x.ndim == 1
        # Validate the whole range up front. Without the lower bound a
        # negative index would silently wrap around and mark a column counted
        # from the end instead of failing.
        if x.size and (x.min() < 0 or x.max() >= y):
            raise ValueError(
                "OneHot: every entry of x must satisfy 0 <= x < y (y=%d)" % y)
        result = numpy.zeros((x.shape[0], y), dtype="float64")
        # Row i gets a single 1 in column x[i] (advanced integer indexing).
        result[numpy.arange(x.shape[0]), x] = 1
        out[0] = result

    def grad(self, inputs, output_gradients):
        """
        Return no gradient for either input.

        Both inputs are integer-valued indices/sizes, so C{None} is returned
        for x and for y (presumably read by the framework as "not
        differentiable" -- confirm against the Op.grad contract).
        """
        return None, None


# Shared, stateless instance for building graphs: one_hot(x, y).
one_hot = OneHot()