Mercurial > pylearn
diff onehotop.py @ 356:18702ceb2096
Added more functions
author | Joseph Turian <turian@iro.umontreal.ca> |
---|---|
date | Thu, 19 Jun 2008 16:18:37 -0400 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/onehotop.py Thu Jun 19 16:18:37 2008 -0400 @@ -0,0 +1,58 @@ +""" +One hot Op +""" + +#from theano import tensor +from theano.tensor import as_tensor, Tensor +from theano.gof import op +from theano.gof.graph import Apply + +import numpy + +class OneHot(op.Op): + """ + Construct a one-hot vector, x out of y. + + @todo: Document inputs and outputs + @todo: Use 'bool' as output dtype? Or, at least 'int64' ? Not float64! + @todo: Use 'bool' as output dtype, not 'int64' ? + @todo: Allow this to operate on column vectors (Tensor) + @todo: Describe better. + """ + + def make_node(self, x, y): + """ + @type x: Vector L{Tensor} of integers + @param x: The entries of the one-hot vector to be one. + @type y: Integer scalar L{Tensor} + @param y: The length (#columns) of the one-hot vectors. + @return: A L{Tensor} of one-hot vectors + + @precondition: x < y for all entries of x + @todo: Check that x and y are int types + """ + x = as_tensor(x) + y = as_tensor(y) + #assert x.dtype[0:3] == "int" + #assert y.dtype[0:3] == "int" + inputs = [x, y] + ##outputs = [tensor.Tensor("int64", broadcastable=[False, False])] + #outputs = [tensor.Tensor("float64", broadcastable=[False, False])] + #outputs = [Tensor("int64", broadcastable=[False, False])] + outputs = [Tensor("float64", broadcastable=[False, False]).make_result()] + node = Apply(op = self, inputs = inputs, outputs = outputs) + return node + + def perform(self, node, (x, y), (out, )): + assert x.dtype == "int64" or x.dtype == "int32" + assert x.ndim == 1 + assert y.dtype == "int64" or x.dtype == "int32" + assert y.ndim == 0 + out[0] = numpy.zeros((x.shape[0], y), dtype="float64") + for c in range(x.shape[0]): + assert x[c] < y + out[0][c, x[c]] = 1 + + def grad(self, (x, y), (out_gradient, )): + return None, None +one_hot = OneHot()