10

我有一个整数类标签的字节张量,例如来自 MNIST 数据集。

 1
 7
 5
[torch.ByteTensor of size 3]

如何使用它来创建 1-hot 向量的张量?

 1  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  1  0  0  0
 0  0  0  0  1  0  0  0  0  0
[torch.DoubleTensor of size 3x10]

我知道我可以用一个循环来做到这一点,但我想知道是否有任何聪明的 Torch 索引可以在一行中为我提供它。

4

2 回答 2

15
indices = torch.LongTensor{1,7,5}:view(-1,1)
one_hot = torch.zeros(3, 10)
one_hot:scatter(2, indices, 1)

您可以scattertorch/torch7 github 自述文件(在 master 分支中)找到文档。

于 2015-08-14T16:55:55.833 回答
2

另一种方法是从单位矩阵中洗牌:

indicies = torch.LongTensor{1,7,5}
one_hot = torch.eye(10):index(1, indicies)

这不是我的主意,我是在karpathy/char-rnn中找到的。

于 2016-11-09T12:16:33.423 回答