1

我正在使用 Tensorflow.Lite.Support 函数进行模型推理,该模型接受两个输入并以图像的形式提供输出。第一个输入是 RGB 图像,而第二个图像是单通道图像。当我运行应用程序进行推理时,出现以下错误:

无法在 1392640 字节的 TensorFlowLite 缓冲区和 4177920 字节的 Java 缓冲区之间进行转换。

我在下面附上了我的代码片段:

protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        Toolbar toolbar = findViewById(R.id.toolbar);
        setSupportActionBar(toolbar);

        FloatingActionButton fab = findViewById(R.id.fab);
        fab.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                Snackbar.make(view, "Replace with your own action", Snackbar.LENGTH_LONG)
                        .setAction("Action", null).show();
            }
        });

        try {
            tflite = new Interpreter(loadModelFile(MainActivity.this, "converted_model.tflite"));


            Bitmap Image = getBitmapFromAsset(this,"001.png");
            Bitmap Mask = getBitmapFromAsset(this,"tmp_mask.png");

            int imageTensorIndex = 0;
            int[] imageShape = tflite.getInputTensor(imageTensorIndex).shape(); // {1, height, width, 3}
            int imageSizeY = imageShape[1];
            int imageSizeX = imageShape[2];
            DataType imageDataType = tflite.getInputTensor(imageTensorIndex).dataType();

            TensorImage inputImageBuffer = new TensorImage(imageDataType);
            inputImageBuffer.load(Image);

            int imageTensorIndex1 = 1;
            int[] imageShape1 = tflite.getInputTensor(imageTensorIndex1).shape(); // {1, height, width, 3}
            int imageSizeY1 = imageShape1[1];
            int imageSizeX1 = imageShape1[2];
            DataType imageDataType1 = tflite.getInputTensor(imageTensorIndex1).dataType();

            TensorImage inputImageBuffer1 = new TensorImage(imageDataType1);
            inputImageBuffer1.load(Mask);


            int OutputTensorIndex = 0;
            int[] OutputShape =
                    tflite.getOutputTensor(OutputTensorIndex).shape(); // {1, NUM_CLASSES}
            DataType OutputDataType = tflite.getOutputTensor(OutputTensorIndex).dataType();

            TensorBuffer outputBuffer = TensorBuffer.createFixedSize(OutputShape, OutputDataType);

            Object[] Inputs = {inputImageBuffer.getBuffer(),inputImageBuffer1.getBuffer()};

            Map<Integer, Object> outputs = new HashMap<>();
            outputs.put(0,outputBuffer);
            tflite.runForMultipleInputsOutputs(Inputs, outputs);

            Toast.makeText(this,"Working",Toast.LENGTH_LONG).show();

        } catch (IOException e) {
            Toast.makeText(this,"Failed",Toast.LENGTH_LONG).show();
            e.printStackTrace();
        }
    }
4

0 回答 0