新答案
从 PyTorch 1.1 开始,one_hot
在torch.nn.functional
. 给定任何索引张量indices
和最大索引n
,您可以创建一个 one_hot 版本,如下所示:
n = 5
indices = torch.randint(0,n, size=(4,7))
one_hot = torch.nn.functional.one_hot(indices, n) # size=(4,7,n)
很老的答案
目前,根据我的经验,在 PyTorch 中切片和索引可能有点痛苦。我假设您不想将张量转换为 numpy 数组。目前我能想到的最优雅的方法是使用稀疏张量,然后转换为密集张量。这将按如下方式工作:
from torch.sparse import FloatTensor as STensor
batch_size = 4
seq_length = 6
feat_dim = 16
batch_idx = torch.LongTensor([i for i in range(batch_size) for s in range(seq_length)])
seq_idx = torch.LongTensor(list(range(seq_length))*batch_size)
feat_idx = torch.LongTensor([[5, 3, 2, 11, 15, 15], [1, 4, 6, 7, 3, 3],
[2, 4, 7, 8, 9, 10], [11, 12, 15, 2, 5, 7]]).view(24,)
my_stack = torch.stack([batch_idx, seq_idx, feat_idx]) # indices must be nDim * nEntries
my_final_array = STensor(my_stack, torch.ones(batch_size * seq_length),
torch.Size([batch_size, seq_length, feat_dim])).to_dense()
print(my_final_array)
注意:PyTorch 目前正在进行一些工作,这将在接下来的两三周内添加 numpy 风格的广播和其他功能以及其他功能。所以有可能,在不久的将来会有更好的解决方案。
希望这对您有所帮助。