我有一个经由 PyTorch -> ONNX -> TensorFlow 转换得到的 TensorFlow 模型,对它有一个疑问。问题在于,转换后的 TensorFlow 模型要求输入采用 PyTorch 的格式(批量大小、通道数、高度、宽度,即 NCHW),而不是 TensorFlow 的格式(批量大小、高度、宽度、通道数,即 NHWC)。因此,我无法将该模型进一步用于 Vitis AI 的后续处理。
所以我想问:有没有办法借助 ONNX、TensorFlow 1 或其他工具,把这种 PyTorch 输入格式转换为 TensorFlow 的输入格式?
我的代码如下:
Pytorch -> Onnx
from hardnet import hardnet
import torch
import onnx
# Restore the trained weights. map_location='cpu' lets a checkpoint that was
# saved from a GPU load on a machine without CUDA; the export below runs on
# the CPU anyway because dummy_input is created on the default (CPU) device.
ckpt = torch.load('../hardnet.pth', map_location='cpu')
model_state_dict = ckpt['model_state_dict']

# NOTE(review): 11 is passed straight to the hardnet constructor -- presumably
# the number of output classes; confirm against hardnet.py.
model = hardnet(11)
model.load_state_dict(model_state_dict)
# Switch to evaluation mode before the model is traced for export.
model.eval()

# The exporter traces the model with this tensor, so the ONNX input is laid
# out as (batch, channels, height, width) = (1, 3, 1080, 1920).
dummy_input = torch.randn(1, 3, 1080, 1920)
input_names = ['input0']
output_names = ['output0']
output_file = 'hardnet.onnx'
torch.onnx.export(
    model,
    dummy_input,
    output_file,
    verbose=True,
    input_names=input_names,
    output_names=output_names,
    opset_version=11,
    keep_initializers_as_inputs=True,
)

# Reload the exported file and run the ONNX checker on it.
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)
print('Passed Onnx')
Onnx -> Tensorflow 1(使用 Tensorflow 1.15)
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import onnx
from onnx_tf.backend import prepare
# Read the ONNX model produced by the PyTorch export step.
output_file = 'hardnet.onnx'
onnx_model = onnx.load(output_file)

# Build the onnx-tf backend representation and write it out as a .pb file.
tf_rep = prepare(onnx_model)
tf_rep.export_graph('hardnet.pb')

# The code that follows works with explicit graphs and sessions, so eager
# execution is switched off.
tf.compat.v1.disable_eager_execution()
def load_pb(path_to_pb: str) -> "tf.Graph":
    """Load a serialized GraphDef (.pb) file into a new TensorFlow graph.

    From: https://stackoverflow.com/questions/51278213/what-is-the-use-of-a-pb-file-in-tensorflow-and-how-does-it-work

    Args:
        path_to_pb: Path of the binary protobuf file to read.

    Returns:
        A new ``tf.Graph`` holding the imported nodes. The nodes are imported
        with an empty name prefix, so tensors keep the names stored in the
        file (e.g. ``'input0:0'``).
    """
    # tf.io.gfile.GFile and tf.compat.v1.GraphDef replace the deprecated
    # tf.gfile.GFile / tf.GraphDef aliases, matching the tf.compat.v1 usage
    # elsewhere in this script.
    with tf.io.gfile.GFile(path_to_pb, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
    return graph
graph = load_pb('hardnet.pb')
# Named *_tensor so that the built-in input() is not shadowed.
input_tensor = graph.get_tensor_by_name('input0:0')
output_tensor = graph.get_tensor_by_name('output0:0')

# Per-channel normalization constants applied after scaling pixels to [0, 1].
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

img = cv2.imread('train_0.jpg', cv2.IMREAD_COLOR)
# cv2.imread signals a missing or unreadable file by returning None instead
# of raising, so fail here with a clear message.
if img is None:
    raise FileNotFoundError("could not read image 'train_0.jpg'")
# NOTE(review): cv2.imread yields channels in BGR order; confirm that the
# mean/std above and the trained model expect that order.
img = cv2.resize(img, (1920, 1080))
img = img / 255
img = (img - mean) / std

# To Pytorch format: (height, width, channels) -> (channels, height, width),
# then add a leading batch axis, giving shape (1, 3, 1080, 1920).
img = np.transpose(img, (2, 0, 1))[np.newaxis, ...]

with tf.compat.v1.Session(graph=graph) as sess:
    pred = sess.run(output_tensor, {input_tensor: img})