1

当我尝试在屏幕上显示/打印一些张量时,我面临类似以下的情况,而不是获得最终结果,似乎 libtorch 显示带有乘数的张量(即0.01*,如下所示):

offsets.shape: [1, 4, 46, 85]
probs.shape: [46, 85]
offsets: (1,1,.,.) =
 0.01 *
  0.1006  1.2322
  -2.9587 -2.2280

(1,2,.,.) =
 0.01 *
  1.3772  1.3971
  -1.2813 -0.8563

(1,3,.,.) =
 0.01 *
  6.2367  9.2561
   3.5719  5.4744

(1,4,.,.) =
  0.2901  0.2963
  0.2618  0.2771
[ CPUFloatType{1,4,2,2} ]
probs: 0.0001 *
 1.4593  1.0351
  6.6782  4.9104
[ CPUFloatType{2,2} ]

如何禁用此行为并获得最终输出?我试图将其显式转换为浮点数,希望这将导致最终输出被存储/显示,但这也不起作用。

4

1 回答 1

2

根据 libtorch 输出张量的源代码,在存储库中搜索“*”字符串后,发现这个“漂亮打印”是在 aten/src/ATen/core/Formatting.cpp 翻译单元中完成的。刻度和星号在此处添加:

static void printScale(std::ostream & stream, double scale) {
  FormatGuard guard(stream);
  stream << defaultfloat << scale << " *" << std::endl;
}

后来张量的所有坐标除以scale

if(scale != 1) {
  printScale(stream, scale);
}
double* tensor_p = tensor.data_ptr<double>();
for(int64_t i = 0; i < tensor.size(0); i++) {
  stream << std::setw(sz) << tensor_p[i]/scale << std::endl;
}

基于这个翻译单元,这是完全不可配置的。

我想你在这里有两个选择:

  1. 调整功能并最低限度地编辑现有功能以满足您的要求。
  2. 在 Formatting.cpp 中删除(或添加#ifdef<<张量的运算符重载并提供您自己的实现。但是,在构建 libtorch 时,您必须将其链接到包含该方法实现的目标。

但是,这两个选项都需要您更改第 3 方代码,我相信这非常糟糕。

于 2020-08-24T12:08:52.620 回答