0

我正在尝试实施变压器并停留在某一点。

假设我有形状 [2,20] 的输入序列,其中 2 是样本数,20 是序列中的单词数(序列长度)。

因此,我创建了一个形状为 [1,20] 的数组,如 [0,1,2, ... 19]。现在我想堆叠它,最终形状应该是 [2,20] 以与输入序列一致。像下面

[[0,1,2, ... 19],
[0,1,2, ... 19]]

是否有这样做的火炬功能。我可以循环并创建数据和数组,但想避免它。

4

1 回答 1

0

如果要堆叠的张量的形状为 [1,20],则可以使用 torch.cat()

t1 = torch.zeros([1,5]) # tensor([[0., 0., 0., 0., 0.]])
t2 = torch.ones([1,5]) # tensor([[1., 1., 1., 1., 1.]])

torch.cat([t1, t2]) # tensor([[0., 0., 0., 0., 0.],
                              [1., 1., 1., 1., 1.]])

如果张量是一维的,你可以简单地使用 torch.stack()

t1 = torch.zeros([5]) # tensor([0., 0., 0., 0., 0.])
t2 = torch.ones([5]) # tensor([1., 1., 1., 1., 1.])

torch.stack([t1, t2]) # tensor([[0., 0., 0., 0., 0.],
                                [1., 1., 1., 1., 1.]])

现在,对于您的情况,您可以使用更短的方法:

torch.arange(0,20).repeat(2,1) # tensor([[0,1,2, ... 19],
                                         [0,1,2, ... 19]])

于 2021-11-10T03:55:28.397 回答