我从这个示例中将经过训练的 LSTM 神经网络从 Matlab 导出到 ONNX。然后我尝试使用ONNX Runtime C#运行这个网络。但是,看起来我做错了什么,网络不记得上一步的状态。
网络应使用以下输出响应输入序列:
输入:[0.258881980200294];输出:[0.311363101005554]
输入:[1.354147904050896];输出:[1.241550326347351]
输入:[ 0.258881980200294, 1.354147904050896 ];输出:[0.311363101005554,1.391810059547424]
前两个示例是仅包含一个元素的序列。最后一个例子是两个元素的序列。这些输出在 Matlab 中计算。我在 Matlab 中重置网络,在每个新序列执行它之间。
然后我尝试使用 ONNX Runtime 运行相同的网络。这是我的 C# 代码:
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections;
using System.Collections.Generic;
namespace OnnxTest
{
public sealed class OnnxRuntimeTest
{
public OnnxRuntimeTest(ILogger logger)
{
this.logger = logger ?? throw new ArgumentNullException(nameof(logger));
}
private const string modelPath = @"E:\Documents\MATLAB\NeuralNetworkExport\onnx_lstm_medic.onnx";
private readonly ILogger logger;
public void Run()
{
using (var session = new InferenceSession(modelPath))
{
// Input values from the example above:
var input1 = GenerateInputValue(0.258881980200294f);
var input2 = GenerateInputValue(1.35414790405090f);
// I create a container to push the first value:
var container = new List<NamedOnnxValue>() { input1 };
//Run the inference
using (var results = session.Run(container))
{
// dump the results
foreach (var r in results)
{
logger.Log(string.Format("Output for {0}", r.Name));
logger.Log(r.AsTensor<float>().GetArrayString());
// Outputs 0,3113631 - as expected
}
}
// The same code to push the second value:
var container2 = new List<NamedOnnxValue>() { input2 };
using (var results = session.Run(container2))
{
// dump the results
foreach (var r in results)
{
logger.Log(string.Format("Output for {0}", r.Name));
logger.Log(r.AsTensor<float>().GetArrayString());
// Outputs 1,24155 - as though this is the first input value
}
}
}
}
private NamedOnnxValue GenerateInputValue(float inputValue)
{
float[] inputData = new float[] { inputValue };
int[] dimensions = new int[] { 1, 1, 1 };
var tensor = new DenseTensor<float>(inputData, dimensions);
return NamedOnnxValue.CreateFromTensor("sequenceinput", tensor);
}
如您所见,第二个会话运行结果为 1,24155,而不是预期值 (1.391810059547424),就好像网络仍处于初始状态一样。看起来我没有保存 LSTM 网络的状态,但我在文档中找不到如何做到这一点。
那么,有谁知道如何让 LSTM 保持其状态?