1

为什么torch::Tensor::is_same以下断言失败?使用 C++ PyTorch API 将张量写入文件,然后再次读入另一个张量,并is_same比较两个张量:

torch::Tensor x_sequence = torch::linspace(0, M_PI, 1000);    
torch::save(x_sequence, "x_sequence.dat");
torch::Tensor x_read;
torch::load(x_read, "x_sequence.dat");
assert(x_read.is_same(x_sequence));  

这导致:

int main(int, char**): Assertion `x_read.is_same(x_sequence)' failed.

使用

  • python-pytorch,Arch Linux 上的版本 1.6.0-2
  • g++ (GCC) 10.1.0
4

1 回答 1

2

torch::Tensor::is_same(const torch::Tensor& other)在这里定义。重要的是要注意 aTensor实际上是底层TensorImpl类的指针(它实际上保存数据)。

因此,当您调用 时is_same,实际上检查的是您的指针是否相同,即您的 2 个张量是否指向相同的底层内存。这是一个非常简单的示例,可以很好地理解它:

auto x = torch::randn({4,4});
auto copy = x;
auto clone = x.clone();
std::cout << x.is_same(copy) << " " << x.is_same(clone) << std::endl;
>>> 0 1

在这里,调用clone强制 pytorch 将数据复制到另一个内存位置。因此,指针不同并is_same返回 false。

如果您想实际比较这些值,您别无选择,只能计算两个张量之间的差异并计算该差异接近 0 的程度。

于 2020-08-18T16:27:47.537 回答