0

我使用 MobileNet 在一个特定的训练数据集上构建了一个模型。用测试集评估时,Keras 模型(model.h5)的准确率约为 92%。随后我用以下代码把模型转换为 tflite 格式:

# Load the trained Keras model and convert it to a TFLite flatbuffer
# (no converter optimizations are set here).
model = tf.keras.models.load_model('modelos TensorflowLite/MobileNet.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Write through a context manager so the file handle is flushed and closed
# deterministically instead of relying on garbage collection.
with open("MobileNet.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)

在 Python 中使用 tflite 解释器对同一测试集运行 tflite 模型时,得到的准确率与 Keras 模型非常接近,约为 92%。在解释器中执行单次推理的代码如下:

# Load the converted model, query its input/output tensor metadata and
# allocate the tensors before any inference is run.
interpreter = tf.lite.Interpreter(model_path="MobileNet.tflite")
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.allocate_tensors()

# NOTE(review): the lines below are indented and reference `file`, so they
# presumably belong to a loop over image files whose header is not shown here.
    
    # read and preprocess the image
    # NOTE(review): per the OpenCV docs, imread returns channels in BGR order;
    # any other runtime feeding this model should use the same channel order.
    img = cv2.imread("image.jpg")
    new_img = cv2.resize(img, (300, 300))
    new_img = new_img.astype(np.float32)
    # scale pixel values from [0, 255] to [0, 1]
    new_img /= 255.
    
    # input_details[0]['index'] = the index which accepts the input
    # the image is wrapped in a list to add the leading batch dimension
    interpreter.set_tensor(input_details[0]['index'], [new_img])
    
    # run the interpreter's prediction
    interpreter.invoke()
    
    # output_details[0]['index'] = the index which provides the output
    output_data = interpreter.get_tensor(output_details[0]['index'])
    
    print("For file {}, the output is {}".format(file.stem, output_data))

问题出现在 Android Studio 中对测试集进行测试时。使用同一个转换后的 tflite 模型、同一个测试集,准确率只有 39%。需要说明的是,该模型没有经过量化。我针对单张图片比较了 3 个类别各自的输出概率。对于下面这张图片,Keras 模型和 tflite 模型(Python 解释器)都能正确分类,但在 Android 上却不能:

| 概率 | Keras 模型 (.h5) | tflite(Python 解释器) | tflite(Android) |
| --- | --- | --- | --- |
| 正确类别的概率 | 9.6e-01 | 9.6e-01 | 3.2e-6 |

我的问题不在于将 .h5 模型转换为 .tflite 时精度低。我的问题是 tflite 模型在 python 解释器中可以正常工作,但在 android studio 中实现时非常糟糕。

加载图像的代码:

/**
 * Loads the bitmap into the shared input tensor image and runs the
 * preprocessing pipeline (square crop/pad, resize, rotation, normalization).
 *
 * @param bitmap            the bitmap to be loaded
 * @param sensorOrientation the sensor orientation in degrees
 * @return the preprocessed tensor image
 */
private TensorImage loadImage(Bitmap bitmap, int sensorOrientation) {
    // Copy the bitmap pixels into the reusable TensorImage.
    inputImageBuffer.load(bitmap);

    // Square side: the shorter of the two bitmap edges.
    final int squareSide = Math.min(bitmap.getWidth(), bitmap.getHeight());
    // Integer number of 90-degree turns derived from the sensor orientation.
    final int quarterTurns = sensorOrientation / 90;

    // Build each preprocessing step as a named local, in application order.
    final ResizeWithCropOrPadOp squareOp = new ResizeWithCropOrPadOp(squareSide, squareSide);
    final ResizeOp scaleOp = new ResizeOp(imageResizeX, imageResizeY, ResizeOp.ResizeMethod.BILINEAR);
    final Rot90Op rotateOp = new Rot90Op(quarterTurns);
    final NormalizeOp normalizeOp = new NormalizeOp(IMAGE_MEAN, IMAGE_STD);

    final ImageProcessor pipeline = new ImageProcessor.Builder()
            .add(squareOp)
            .add(scaleOp)
            .add(rotateOp)
            .add(normalizeOp)
            .build();

    return pipeline.process(inputImageBuffer);
}

执行预测的代码:

// Preprocess the bitmap into the model's input tensor image.
inputImageBuffer = loadImage(bitmap, sensorOrientation);
// Run inference; the output buffer is rewound first so it is written from position 0.
tensorClassifier.run(
        inputImageBuffer.getBuffer(),
        probabilityImageBuffer.getBuffer().rewind());

完整的分类代码(ImageClassifier.java):

import android.app.Activity;
import android.graphics.Bitmap;
import android.widget.Toast;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
import org.tensorflow.lite.support.image.ops.Rot90Op;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * Runs a float (non-quantized) TensorFlow Lite image classification model and
 * returns the top-scoring classes for a bitmap.
 *
 * <p>Instances hold a native {@link Interpreter}; call {@link #close()} when the
 * classifier is no longer needed.
 */
public class ImageClassifier implements AutoCloseable {

    // Non-quantized model: NormalizeOp(0, 1) leaves the output probabilities unchanged.
    private static final float PROBABILITY_MEAN = 0.0f;
    private static final float PROBABILITY_STD = 1.0f;

    /*
     * Input normalization. NormalizeOp computes (pixel - IMAGE_MEAN) / IMAGE_STD.
     * The Python pipeline used to evaluate this model divides pixels by 255,
     * i.e. feeds values in [0, 1]. The previous 127.5 / 127.5 pair produced
     * values in [-1, 1], which does not match what the model was evaluated with
     * and yields wrong probabilities on device. Mean 0 / std 255 reproduces the
     * [0, 1] scaling.
     */
    private static final float IMAGE_MEAN = 0.0f;
    private static final float IMAGE_STD = 255.0f;

    /** Maximum number of results returned by {@link #recognizeImage}. */
    private static final int MAX_SIZE = 3;

    /**
     * Image size along the x axis.
     */
    private final int imageResizeX;
    /**
     * Image size along the y axis.
     */
    private final int imageResizeY;

    /**
     * Labels corresponding to the output of the vision model.
     */
    private final List<String> labels;

    /**
     * An instance of the driver class to run model inference with Tensorflow Lite.
     */
    private final Interpreter tensorClassifier;
    /**
     * Input image TensorBuffer.
     */
    private TensorImage inputImageBuffer;
    /**
     * Output probability TensorBuffer.
     */
    private final TensorBuffer probabilityImageBuffer;
    /**
     * Processor applied to the output probabilities after inference.
     */
    private final TensorProcessor probabilityProcessor;

    /**
     * Creates a classifier backed by the {@code MobileNet.tflite} asset and the
     * {@code labels_mobilenet.txt} label file.
     *
     * @param activity the current activity, used to load the model and label assets
     * @throws IOException if the model or the label file cannot be read
     */
    public ImageClassifier(Activity activity) throws IOException {
        // The memory-mapped TensorFlow Lite model.
        MappedByteBuffer classifierModel = FileUtil.loadMappedFile(activity,
                "MobileNet.tflite");
        // Loads labels out from the label file.
        labels = FileUtil.loadLabels(activity, "labels_mobilenet.txt");

        tensorClassifier = new Interpreter(classifierModel, null);

        // The model is read through its first input and first output tensor.
        int imageTensorIndex = 0; // input
        int probabilityTensorIndex = 0; // output

        int[] inputImageShape = tensorClassifier.getInputTensor(imageTensorIndex).shape();
        DataType inputDataType = tensorClassifier.getInputTensor(imageTensorIndex).dataType();

        int[] outputImageShape = tensorClassifier.getOutputTensor(probabilityTensorIndex).shape();
        DataType outputDataType = tensorClassifier.getOutputTensor(probabilityTensorIndex).dataType();

        // NOTE(review): assumes the input tensor is laid out as
        // {batch, height, width, channels} - confirm against the model.
        imageResizeX = inputImageShape[2];
        imageResizeY = inputImageShape[1];

        // Creates the input tensor.
        inputImageBuffer = new TensorImage(inputDataType);

        // Creates the output tensor and its processor.
        probabilityImageBuffer = TensorBuffer.createFixedSize(outputImageShape, outputDataType);

        // Creates the post processor for the output probability.
        probabilityProcessor = new TensorProcessor.Builder()
                .add(new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD))
                .build();
    }

    /**
     * Runs inference on the bitmap and returns the best classification results.
     *
     * <p>Each result's name is the position of the class in the label/probability
     * map (as a string), not the label text.
     *
     * @param bitmap            the bitmap of the image
     * @param sensorOrientation orientation of the camera in degrees
     * @return at most {@code MAX_SIZE} results, sorted by descending confidence
     */
    public List<Recognition> recognizeImage(final Bitmap bitmap, final int sensorOrientation) {
        inputImageBuffer = loadImage(bitmap, sensorOrientation);
        // Rewind the output buffer so the interpreter writes from position 0.
        tensorClassifier.run(inputImageBuffer.getBuffer(), probabilityImageBuffer.getBuffer().rewind());

        // Gets the map of label and probability.
        Map<String, Float> labelledProbability = new TensorLabel(labels,
                probabilityProcessor.process(probabilityImageBuffer)).getMapWithFloatValue();

        // One recognition per class, identified by its index in map iteration order.
        // NOTE(review): this relies on the map iterating in label-file order - confirm.
        List<Recognition> recognitions = new ArrayList<>(labelledProbability.size());
        int idLabel = 0;
        for (Map.Entry<String, Float> entry : labelledProbability.entrySet()) {
            recognitions.add(new Recognition(String.valueOf(idLabel), entry.getValue()));
            idLabel++;
        }

        Collections.sort(recognitions);

        // Clamp so that models with fewer than MAX_SIZE classes do not throw.
        return recognitions.subList(0, Math.min(MAX_SIZE, recognitions.size()));
    }

    /**
     * Releases the native resources held by the underlying interpreter. The
     * classifier must not be used after this call.
     */
    @Override
    public void close() {
        tensorClassifier.close();
    }

    /**
     * Loads the image into the tensor input buffer and applies the pre-processing steps.
     *
     * @param bitmap            the bitmap to be loaded
     * @param sensorOrientation the sensor orientation in degrees
     * @return the pre-processed tensor image
     */
    private TensorImage loadImage(Bitmap bitmap, int sensorOrientation) {
        // Loads bitmap into a TensorImage.
        inputImageBuffer.load(bitmap);

        int noOfRotations = sensorOrientation / 90;
        int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight());

        // NOTE(review): two remaining differences from an OpenCV-based Python
        // pipeline should be verified if accuracy still differs on device:
        // (1) the bitmap is cropped/padded to a square before resizing, while a
        //     plain resize keeps the full frame; (2) OpenCV's imread yields BGR
        //     channel order, whereas the bitmap is presumably loaded as RGB.
        ImageProcessor imageProcessor = new ImageProcessor.Builder()
                .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
                .add(new ResizeOp(imageResizeX, imageResizeY, ResizeOp.ResizeMethod.BILINEAR))
                .add(new Rot90Op(noOfRotations))
                .add(new NormalizeOp(IMAGE_MEAN, IMAGE_STD))
                .build();
        return imageProcessor.process(inputImageBuffer);
    }

    /**
     * A mutable result returned by the classifier describing what was recognized.
     * Natural ordering is by descending confidence.
     */
    public class Recognition implements Comparable<Recognition> {
        /**
         * Display name for the recognition.
         */
        private String name;
        /**
         * A sortable score for how good the recognition is relative to others. Higher should be better.
         */
        private float confidence;

        public Recognition() {
        }

        public Recognition(String name, float confidence) {
            this.name = name;
            this.confidence = confidence;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public float getConfidence() {
            return confidence;
        }

        public void setConfidence(float confidence) {
            this.confidence = confidence;
        }

        @Override
        public String toString() {
            return "Recognition{" +
                    "name='" + name + '\'' +
                    ", confidence=" + confidence +
                    '}';
        }

        /**
         * Orders recognitions so that higher confidence sorts first.
         *
         * @param other the recognition to compare against
         * @return negative if this recognition has the higher confidence
         */
        @Override
        public int compareTo(Recognition other) {
            return Float.compare(other.confidence, this.confidence);
        }
    }


}
4

0 回答 0