4

我正在尝试在 C++ 中的 jit 跟踪模型上运行推理,目前我在 Python 中获得的输出与我在 C++ 中获得的输出不同。

最初我认为这是由 jit 模型本身引起的,但现在我不这么认为,因为我发现 C++ 代码中的输入张量存在一些小的偏差。我相信我按照文档的指示做了所有事情,这样也可以在torch::from_blob. 我不知道!

因此,为了确定是哪种情况,这里是 Python 和 C++ 中的代码片段以及用于测试它的示例输入。

这是示例图像:

对于 Pytorch,运行以下代码片段:

import cv2
import torch
from PIL import Image 
import math
import numpy as np

img = Image.open('D:/Codes/imgs/profile6.jpg')
width, height = img.size
scale = 0.6
sw, sh = math.ceil(width * scale), math.ceil(height * scale)
img = img.resize((sw, sh), Image.BILINEAR)
img = np.asarray(img, 'float32')

# preprocess it 
img = img.transpose((2, 0, 1))
img = np.expand_dims(img, 0)
img = (img - 127.5) * 0.0078125
img = torch.from_numpy(img)

对于 C++:

#include <iostream>
#include <torch/torch.h>
#include <torch/script.h>
using namespace torch::indexing;

#include <opencv2/core.hpp>
#include<opencv2/imgproc/imgproc.hpp>
#include<opencv2/highgui/highgui.hpp>

void test15()
{
    std::string pnet_path = "D:/Codes//MTCNN/pnet.jit"; 
    cv::Mat img = cv::imread("D:/Codes/imgs/profile6.jpg");
    int width = img.cols;
    int height = img.rows;
    float scale = 0.6f;
    int sw = int(std::ceil(width * scale));
    int sh = int(std::ceil(height * scale));

    //cv::Mat img;
    cv::resize(img, img, cv::Size(sw, sh), 0, 0, 1);

    auto tensor_image = torch::from_blob(img.data, { img.rows, img.cols, img.channels() }, at::kByte);
    tensor_image = tensor_image.permute({ 2,0,1 });
    tensor_image.unsqueeze_(0);
    tensor_image = tensor_image.toType(c10::kFloat).sub(127.5).mul(0.0078125);
    tensor_image.to(c10::DeviceType::CPU);
}

### Input comparison : 
and here are the tensor values both in Python and C++ 
Pytorch input (`img[:, :, :10, :10]`):

```python
img: tensor([[
    [[0.3555,  0.3555,  0.3477,  0.3555,  0.3711,  0.3945,  0.3945,  0.3867,  0.3789,  0.3789],
    [ 0.3477,  0.3555,  0.3555,  0.3555,  0.3555,  0.3555,  0.3555,  0.3477,  0.3398,  0.3398],
    [ 0.3320,  0.3242,  0.3320,  0.3242,  0.3320,  0.3398,  0.3398,  0.3242,  0.3164,  0.3242],
    [ 0.2852,  0.2930,  0.2852,  0.2852,  0.2930,  0.2930,  0.2930,  0.2852,  0.2773,  0.2773],
    [ 0.2539,  0.2617,  0.2539,  0.2617,  0.2539,  0.2148,  0.2148,  0.2148,  0.2070,  0.2070],
    [ 0.1914,  0.1914,  0.1836,  0.1836,  0.1758,  0.1523,  0.1367,  0.1211,  0.0977,  0.0898],
    [ 0.1367,  0.1211,  0.0977,  0.0820,  0.0742,  0.0586,  0.0273,  -0.0195, -0.0742, -0.0820],
    [-0.0039, -0.0273, -0.0508, -0.0664, -0.0898, -0.1211, -0.1367, -0.1523, -0.1758, -0.1758],
    [-0.2070, -0.2070, -0.2148, -0.2227, -0.2148, -0.1992, -0.1992, -0.1836, -0.1680, -0.1680],
    [-0.2539, -0.2461, -0.2383, -0.2305, -0.2227, -0.1914, -0.1836, -0.1758, -0.1680, -0.1602]],

    [[0.8398,  0.8398,  0.8320,  0.8242,  0.8320,  0.8477,  0.8398, 0.8320,  0.8164,  0.8164],
    [ 0.8320,  0.8242,  0.8164,  0.8164,  0.8086,  0.8008,  0.7930, 0.7852,  0.7695,  0.7695],
    [ 0.7852,  0.7852,  0.7773,  0.7695,  0.7695,  0.7617,  0.7539, 0.7383,  0.7305,  0.7148],
    [ 0.7227,  0.7070,  0.7070,  0.6992,  0.6914,  0.6836,  0.6836, 0.6680,  0.6523,  0.6367],
    [ 0.6289,  0.6211,  0.6211,  0.6211,  0.6055,  0.5586,  0.5508, 0.5352,  0.5273,  0.5039],
    [ 0.4805,  0.4727,  0.4648,  0.4648,  0.4570,  0.4180,  0.3945, 0.3633,  0.3477,  0.3164],
    [ 0.3555,  0.3398,  0.3086,  0.2930,  0.2695,  0.2461,  0.2070, 0.1523,  0.1055,  0.0820],
    [ 0.1367,  0.1133,  0.0820,  0.0508,  0.0273, -0.0117, -0.0352, -0.0508, -0.0820, -0.0898],
    [-0.1211, -0.1289, -0.1445, -0.1602, -0.1602, -0.1523, -0.1523, -0.1367, -0.1367, -0.1289],
    [-0.2070, -0.1992, -0.1992, -0.1992, -0.1992, -0.1680, -0.1680, -0.1602, -0.1523, -0.1445]],

    [[0.9492,  0.9414,  0.9336,  0.9180,  0.9180,  0.9336,  0.9258, 0.9023,  0.8867,  0.9023],
    [ 0.9258,  0.9258,  0.9102,  0.9023,  0.8945,  0.8789,  0.8633, 0.8477,  0.8320,  0.8398],
    [ 0.8711,  0.8633,  0.8555,  0.8477,  0.8320,  0.8242,  0.8086, 0.7930,  0.7852,  0.7773],
    [ 0.7852,  0.7773,  0.7617,  0.7539,  0.7461,  0.7305,  0.7148, 0.6992,  0.6914,  0.6836],
    [ 0.6758,  0.6680,  0.6602,  0.6602,  0.6367,  0.5820,  0.5742, 0.5508,  0.5430,  0.5273],
    [ 0.5117,  0.5117,  0.4961,  0.4883,  0.4727,  0.4336,  0.4102, 0.3711,  0.3477,  0.3242],
    [ 0.3867,  0.3711,  0.3398,  0.3164,  0.2930,  0.2539,  0.2148, 0.1523,  0.1055,  0.0820],
    [ 0.1680,  0.1445,  0.1055,  0.0742,  0.0352, -0.0039, -0.0273, -0.0586, -0.0820, -0.0898],
    [-0.0898, -0.0977, -0.1211, -0.1367, -0.1445, -0.1445, -0.1445, -0.1445, -0.1445, -0.1445],
    [-0.1758, -0.1680, -0.1680, -0.1680, -0.1680, -0.1523, -0.1523, -0.1602, -0.1602, -0.1523]]]])

C++/Libtorch 张量值 ( img.index({Slice(), Slice(), Slice(None, 10), Slice(None, 10)});):

img: (1,1,.,.) =
  0.3555  0.3555  0.3555  0.3555  0.3555  0.4023  0.3945  0.3867  0.3789  0.3789
  0.3633  0.3633  0.3555  0.3555  0.3555  0.3555  0.3477  0.3555  0.3398  0.3398
  0.3398  0.3320  0.3320  0.3242  0.3398  0.3320  0.3398  0.3242  0.3242  0.3242
  0.2930  0.2930  0.2852  0.2773  0.2852  0.2930  0.2852  0.2852  0.2773  0.2852
  0.2695  0.2695  0.2617  0.2773  0.2695  0.2227  0.2227  0.2227  0.2148  0.2148
  0.1914  0.1914  0.1914  0.1914  0.1914  0.1602  0.1445  0.1289  0.1055  0.0977
  0.1289  0.1133  0.0820  0.0742  0.0586  0.0586  0.0195 -0.0273 -0.0820 -0.0898
  0.0039 -0.0195 -0.0508 -0.0664 -0.0820 -0.1289 -0.1445 -0.1602 -0.1836 -0.1836
 -0.2070 -0.2148 -0.2227 -0.2383 -0.2305 -0.2070 -0.2070 -0.1914 -0.1836 -0.1758
 -0.2539 -0.2461 -0.2461 -0.2383 -0.2305 -0.1914 -0.1914 -0.1758 -0.1680 -0.1602

(1,2,.,.) =
  0.8398  0.8398  0.8242  0.8164  0.8242  0.8555  0.8398  0.8320  0.8242  0.8242
  0.8320  0.8320  0.8242  0.8242  0.8086  0.8008  0.7930  0.7773  0.7695  0.7617
  0.7930  0.7852  0.7773  0.7695  0.7695  0.7695  0.7539  0.7461  0.7305  0.7227
  0.7070  0.7070  0.6992  0.6992  0.6914  0.6836  0.6758  0.6602  0.6523  0.6367
  0.6367  0.6367  0.6289  0.6289  0.6211  0.5664  0.5586  0.5430  0.5352  0.5117
  0.4805  0.4805  0.4805  0.4648  0.4727  0.4258  0.4023  0.3711  0.3555  0.3320
  0.3398  0.3320  0.3008  0.2773  0.2617  0.2461  0.1992  0.1445  0.0898  0.0586
  0.1367  0.1211  0.0898  0.0508  0.0273 -0.0195 -0.0352 -0.0664 -0.0898 -0.1055
 -0.1211 -0.1289 -0.1367 -0.1602 -0.1602 -0.1523 -0.1523 -0.1445 -0.1445 -0.1367
 -0.2148 -0.2070 -0.2070 -0.2070 -0.1992 -0.1680 -0.1680 -0.1602 -0.1523 -0.1445

(1,3,.,.) =
  0.9414  0.9414  0.9336  0.9180  0.9102  0.9336  0.9258  0.9023  0.8945  0.9023
  0.9180  0.9180  0.9102  0.9102  0.8945  0.8711  0.8633  0.8555  0.8242  0.8477
  0.8711  0.8711  0.8633  0.8477  0.8320  0.8164  0.8164  0.7930  0.7852  0.7852
  0.7773  0.7773  0.7539  0.7461  0.7305  0.7148  0.7070  0.6992  0.6836  0.6758
  0.6836  0.6836  0.6758  0.6680  0.6445  0.5898  0.5820  0.5586  0.5508  0.5352
  0.5273  0.5195  0.5117  0.4883  0.4883  0.4414  0.4102  0.3789  0.3633  0.3398
  0.3867  0.3633  0.3320  0.3008  0.2695  0.2539  0.2070  0.1445  0.0898  0.0664
  0.1836  0.1523  0.1133  0.0742  0.0352 -0.0117 -0.0352 -0.0664 -0.0898 -0.1055
 -0.0820 -0.0977 -0.1211 -0.1367 -0.1445 -0.1445 -0.1445 -0.1367 -0.1445 -0.1445
 -0.1758 -0.1758 -0.1758 -0.1758 -0.1758 -0.1602 -0.1523 -0.1680 -0.1602 -0.1602

[ CPUFloatType{1,3,10,10} ]

顺便说一下,这些是标准化/预处理之前的张量值:

Python:

img.shape: (3, 101, 180)
img: [
 [[173. 173. 172. 173. 175.]
  [172. 173. 173. 173. 173.]
  [170. 169. 170. 169. 170.]
  [164. 165. 164. 164. 165.]
  [160. 161. 160. 161. 160.]]

 [[235. 235. 234. 233. 234.]
  [234. 233. 232. 232. 231.]
  [228. 228. 227. 226. 226.]
  [220. 218. 218. 217. 216.]
  [208. 207. 207. 207. 205.]]

 [[249. 248. 247. 245. 245.]
  [246. 246. 244. 243. 242.]
  [239. 238. 237. 236. 234.]
  [228. 227. 225. 224. 223.]
  [214. 213. 212. 212. 209.]]]

共产党:

img.shape: [1, 3, 101, 180]
img: (1,1,.,.) =
  173  173  173  173  173
  174  174  173  173  173
  171  170  170  169  171
  165  165  164  163  164
  162  162  161  163  162

(1,2,.,.) =
  235  235  233  232  233
  234  234  233  233  231
  229  228  227  226  226
  218  218  217  217  216
  209  209  208  208  207

(1,3,.,.) =
  248  248  247  245  244
  245  245  244  244  242
  239  239  238  236  234
  227  227  224  223  221
  215  215  214  213  210
[ CPUByteType{1,3,5,5} ]

如您所见,乍一看,它们可能看起来相同,但仔细观察,您会发现输入中有许多小的偏差!如何避免这些更改,并获得 C++ 中的确切值?

我想知道是什么导致了这种奇怪的现象发生!

4

1 回答 1

2

明确表示这确实是一个输入问题,更具体地说,这是因为图像首先PIL.Image.open在 Python 中读取,然后更改为numpy数组。如果使用 读取图像OpenCV,那么所有输入方式在 Python 和 C++ 中都是相同的。

更多解释

但是,在我的具体情况下,使用 OpenCV 图像会导致最终结果发生微小变化。最小化这种变化/差异的唯一方法是,当我制作 Opencv 图像灰度并将其馈送到网络时,PIL 输入和 opencv 输入都具有几乎相同的输出。

这是两个示例,pil 图像是 bgr,opencv 处于灰度模式:您需要将它们保存在磁盘上并看到它们几乎相同(左侧是 cv_image,右侧是 pil_image):

img62_cv img62_pil

但是,如果我只是不将 opencv 图像转换为灰度模式(并返回 bgr 以获得 3 个通道),这就是它的外观(左侧是 cv_image,右侧是 pil_image):

img62_cv img62_pil

更新

事实证明这又与输入有关。我们有细微差别的原因是由于模型是在 rgb 图像上训练的,因此通道顺序很重要。使用 PIL 图像时,不同方法会来回发生一些转换,因此它会导致整个事情变得一团糟,就像您之前在上面读到的那样。

长话短说,从转换cv::Mat为 atorch::Tensor或反之转换没有任何问题,问题在于在 Python 和 C++ 中创建图像并将其馈送到网络的方式不同。当 Python 和 C++ 后端都使用 OpenCV 处理图像时,它们的输出和结果匹配 100%。

于 2020-08-24T12:50:09.050 回答