我已经训练了一个神经网络,它以 4 个浮点值作为输入,并为四个类标签返回一个热编码输出。
例如,{2,12,30,4} -> {0, 0, 1, 0}
训练后的模型生成并保存在 .pb 文件中。然后将该模型导入到我的 android 应用程序的资产文件夹中:
inferenceInterface = new TensorFlowInferenceInterface(getAssets(), "tensorflow_lite_xor_nn.pb");
我有以下功能:
private float[] predict(float[] input){
float output[] = new float[4];
inferenceInterface.feed("dense_1_input", input, 4, input.length);
inferenceInterface.run(new String[]{"dense_2/Sigmoid"});
inferenceInterface.fetch("dense_2/Sigmoid", output);
return output;
}
但我收到此错误:
java.lang.IllegalArgumentException: 具有 4 个元素的缓冲区与形状为 [4, 4] 的张量不兼容