有多种方法可以重塑 PyTorch 张量。您可以将这些方法应用于任何维度的张量。
让我们从一个二维2 x 3
张量开始:
x = torch.Tensor(2, 3)
print(x.shape)
# torch.Size([2, 3])
为了给这个问题增加一些鲁棒性,让我们2 x 3
通过在前面添加一个新维度和在中间添加另一个维度来重塑张量,从而产生一个1 x 2 x 1 x 3
张量。
方法1:添加维度None
使用(aka ) 的NumPy 样式插入None
np.newaxis
在您想要的任何位置添加尺寸。见这里。
print(x.shape)
# torch.Size([2, 3])
y = x[None, :, None, :] # Add new dimensions at positions 0 and 2.
print(y.shape)
# torch.Size([1, 2, 1, 3])
方法2:解压
使用torch.Tensor.unsqueeze(i)
(又名torch.unsqueeze(tensor, i)
或就地版本unsqueeze_()
)在第 i 个维度添加一个新维度。返回的张量与原始张量共享相同的数据。在本例中,我们可以使用unqueeze()
两次来添加两个新维度。
print(x.shape)
# torch.Size([2, 3])
# Use unsqueeze twice.
y = x.unsqueeze(0) # Add new dimension at position 0
print(y.shape)
# torch.Size([1, 2, 3])
y = y.unsqueeze(2) # Add new dimension at position 2
print(y.shape)
# torch.Size([1, 2, 1, 3])
在 PyTorch 的实践中,为批处理添加额外的维度可能很重要,因此您可能经常会看到unsqueeze(0)
.
方法三:查看
用于torch.Tensor.view(*shape)
指定所有尺寸。返回的张量与原始张量共享相同的数据。
print(x.shape)
# torch.Size([2, 3])
y = x.view(1, 2, 1, 3)
print(y.shape)
# torch.Size([1, 2, 1, 3])
方法四:重塑
使用torch.Tensor.reshape(*shape)
(aka torch.reshape(tensor, shapetuple)
) 指定所有维度。如果原始数据是连续的并且具有相同的步幅,则返回的张量将是输入的视图(共享相同的数据),否则将是副本。此函数类似于 NumPyreshape()
函数,因为它允许您定义所有维度并可以返回视图或副本。
print(x.shape)
# torch.Size([2, 3])
y = x.reshape(1, 2, 1, 3)
print(y.shape)
# torch.Size([1, 2, 1, 3])
此外,作者在 O'Reilly 2019 年出版的Programming PyTorch for Deep Learning中写道:
view()
现在您可能想知道和之间有什么区别reshape()
。答案是view()
作为原始张量的视图运行,因此如果基础数据发生更改,视图也会更改(反之亦然)。但是,view()
如果所需的视图不连续,可能会引发错误;也就是说,如果从头开始创建所需形状的新张量,它不会共享相同的内存块。如果发生这种情况,您必须先致电tensor.contiguous()
,然后才能使用view()
. 但是,reshape()
所有这些都是在幕后完成的,所以总的来说,我建议使用reshape()
而不是view()
.
方法5:resize_
使用就地函数torch.Tensor.resize_(*sizes)
修改原始张量。该文档指出:
警告。这是一种低级方法。存储被重新解释为 C 连续,忽略当前步幅(除非目标大小等于当前大小,在这种情况下张量保持不变)。在大多数情况下,您将改为使用view()
,它检查连续性,或reshape()
,它在需要时复制数据。要使用自定义步幅就地更改大小,请参阅set_()
。
print(x.shape)
# torch.Size([2, 3])
x.resize_(1, 2, 1, 3)
print(x.shape)
# torch.Size([1, 2, 1, 3])
我的观察
如果您只想添加一个维度(例如,为批次添加第 0 个维度),请使用unsqueeze(0)
. 如果您想完全改变维度,请使用reshape()
.
也可以看看:
pytorch中的reshape和view有什么区别?
view() 和 unsqueeze() 有什么区别?
在 PyTorch 0.4 中,是否建议在可能的情况下reshape
使用view
?