I built a model with MobileNet for a specific training dataset. On the test set, the Keras model (model.h5) reaches about 92% accuracy. I then converted the model to TFLite with the following code:
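```python
import tensorflow as tf

model = tf.keras.models.load_model('modelos TensorflowLite/MobileNet.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Write the converted flatbuffer to disk (the with-block closes the file)
with open("MobileNet.tflite", "wb") as f:
    f.write(tflite_model)
```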
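When I run the TFLite model in Python with the TFLite interpreter on the same test set, I get an accuracy very close to the Keras model's, about 92%. This is the code I use with the interpreter for a single inference: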
```python
import cv2
import numpy as np
import tensorflow as tf

image_path = "image.jpg"

interpreter = tf.lite.Interpreter(model_path="MobileNet.tflite")
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.allocate_tensors()
# Read and preprocess the image (cv2.imread returns BGR pixels)
img = cv2.imread(image_path)
new_img = cv2.resize(img, (300, 300))
new_img = new_img.astype(np.float32)
new_img /= 255.
# input_details[0]['index'] = the index of the input tensor; add a batch dimension
interpreter.set_tensor(input_details[0]['index'], np.expand_dims(new_img, axis=0))
# Run the interpreter's prediction
interpreter.invoke()
# output_details[0]['index'] = the index of the output tensor
output_data = interpreter.get_tensor(output_details[0]['index'])
print("For file {}, the output is {}".format(image_path, output_data))
```
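For reference, the Python script feeds the model values in [0, 1] (`new_img /= 255.`), in the BGR channel order that `cv2.imread` returns by default. A minimal pure-Python sketch of that per-pixel transform; `python_pixel` is a hypothetical helper used only for illustration:

```python
# Sketch of the per-pixel transform applied by the Python inference script.
# cv2.imread returns pixels in BGR order; the script then divides by 255.
def python_pixel(bgr):
    """Scale one (B, G, R) pixel from [0, 255] to [0, 1], keeping BGR order."""
    return tuple(channel / 255.0 for channel in bgr)


# A pure-red pixel as cv2.imread would return it: (B, G, R) = (0, 0, 255).
print(python_pixel((0, 0, 255)))  # (0.0, 0.0, 1.0)
```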
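The problem appears when I run the test set in Android Studio: with the same converted TFLite model, accuracy on the same test set drops to 39%. Note that the model is not quantized. I compared the results for a single image from each of the 3 classes. For the image below, the class is predicted correctly by both the Keras model and the TFLite model in Python, but not on Android: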
| Output | Keras model (.h5) | TFLite (Python interpreter) | TFLite (Android) |
|---|---|---|---|
| Probability of the correct class | 9.6e-01 | 9.6e-01 | 3.2e-6 |
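My problem is not a loss of accuracy when converting the .h5 model to .tflite. My problem is that the TFLite model works correctly in the Python interpreter but performs very poorly once implemented in Android Studio.
Code that loads the image: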
```java
private TensorImage loadImage(Bitmap bitmap, int sensorOrientation) {
    // Loads bitmap into a TensorImage.
    inputImageBuffer.load(bitmap);
    int noOfRotations = sensorOrientation / 90;
    int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight());
    ImageProcessor imageProcessor = new ImageProcessor.Builder()
            .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
            .add(new ResizeOp(imageResizeX, imageResizeY, ResizeOp.ResizeMethod.BILINEAR))
            .add(new Rot90Op(noOfRotations))
            .add(new NormalizeOp(IMAGE_MEAN, IMAGE_STD))
            .build();
    return imageProcessor.process(inputImageBuffer);
}
```
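`NormalizeOp(mean, stddev)` computes `(x - mean) / stddev` for every value. With `IMAGE_MEAN = IMAGE_STD = 127.5f` (the constants in `ImageClassifier.java`), pixels in [0, 255] end up in [-1, 1], whereas the Python script produces [0, 1]. A small pure-Python sketch to check the arithmetic; `normalize` is a hypothetical stand-in that only reproduces the formula, not the TFLite Support library itself. Assuming the model expects the same [0, 1] input as the Python script, a mean of 0 and a stddev of 255 would reproduce `x / 255`:

```python
# Reproduces the formula used by NormalizeOp: output = (input - mean) / stddev.
def normalize(x, mean, stddev):
    return (x - mean) / stddev


for pixel in (0, 127.5, 255):
    android = normalize(pixel, 127.5, 127.5)  # IMAGE_MEAN, IMAGE_STD in ImageClassifier
    python = pixel / 255.0                    # new_img /= 255. in the Python script
    print(pixel, android, python)
# 0 -1.0 0.0
# 127.5 0.0 0.5
# 255 1.0 1.0

# Assumption: mean 0, stddev 255 matches the Python scaling.
print(normalize(51, 0.0, 255.0) == 51 / 255.0)  # True
```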
Code that runs the prediction:
```java
inputImageBuffer = loadImage(bitmap, sensorOrientation);
tensorClassifier.run(inputImageBuffer.getBuffer(), probabilityImageBuffer.getBuffer().rewind());
```
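Besides scaling values, the `loadImage` call above also changes the image geometry: it first cuts the bitmap to a `min(width, height)` square with `ResizeWithCropOrPadOp` and only then resizes, whereas the Python script passes the whole image to `cv2.resize(img, (300, 300))`, stretching it. A pure-Python sketch of the two source regions, assuming the square crop is centered; `center_square` and `full_image` are hypothetical helpers, and the library's exact rounding of the offset may differ:

```python
# Compare the region of the source image that reaches the model in each pipeline.
# Assumption: ResizeWithCropOrPadOp(cropSize, cropSize) keeps a centered square.
def center_square(width, height):
    """Return (left, top, right, bottom) of the centered min(width, height) square."""
    size = min(width, height)
    left = (width - size) // 2
    top = (height - size) // 2
    return (left, top, left + size, top + size)


def full_image(width, height):
    """cv2.resize uses the whole image and stretches it to the target size."""
    return (0, 0, width, height)


# A hypothetical 640x480 photo:
print(center_square(640, 480))  # (80, 0, 560, 480): 160 columns are dropped
print(full_image(640, 480))     # (0, 0, 640, 480): aspect ratio is distorted
```

Note also that `Rot90Op(noOfRotations)` rotates the image whenever `sensorOrientation` is non-zero, which the Python script never does.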
Full classification code (ImageClassifier.java):
```java
import android.app.Activity;
import android.graphics.Bitmap;
import android.widget.Toast;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
import org.tensorflow.lite.support.image.ops.Rot90Op;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class ImageClassifier {
    // Non-quantized model
    private static final float PROBABILITY_MEAN = 0.0f;
    private static final float PROBABILITY_STD = 1.0f;
    private static final float IMAGE_STD = 127.5f;
    private static final float IMAGE_MEAN = 127.5f;
    private static final int MAX_SIZE = 3;

    /** Image size along the x axis. */
    private final int imageResizeX;
    /** Image size along the y axis. */
    private final int imageResizeY;
    /** Labels corresponding to the output of the vision model. */
    private final List<String> labels;
    /** An instance of the driver class to run model inference with TensorFlow Lite. */
    private final Interpreter tensorClassifier;
    /** Input image TensorBuffer. */
    private TensorImage inputImageBuffer;
    /** Output probability TensorBuffer. */
    private final TensorBuffer probabilityImageBuffer;
    /** Processor to apply post-processing to the output probability. */
    private final TensorProcessor probabilityProcessor;

    /**
     * Creates a classifier.
     *
     * @param activity the current activity
     * @throws IOException if the model or label file cannot be loaded
     */
    public ImageClassifier(Activity activity) throws IOException {
        // The loaded TensorFlow Lite model.
        MappedByteBuffer classifierModel = FileUtil.loadMappedFile(activity,
                "MobileNet.tflite");
        // Loads labels out from the label file.
        labels = FileUtil.loadLabels(activity, "labels_mobilenet.txt");
        tensorClassifier = new Interpreter(classifierModel, null);
        // Reads type and shape of input and output tensors, respectively. [START]
        int imageTensorIndex = 0; // input
        int probabilityTensorIndex = 0; // output
        int[] inputImageShape = tensorClassifier.getInputTensor(imageTensorIndex).shape();
        DataType inputDataType = tensorClassifier.getInputTensor(imageTensorIndex).dataType();
        int[] outputImageShape = tensorClassifier.getOutputTensor(probabilityTensorIndex).shape();
        DataType outputDataType = tensorClassifier.getOutputTensor(probabilityTensorIndex).dataType();
        imageResizeX = inputImageShape[2];
        imageResizeY = inputImageShape[1];
        // Creates the input tensor.
        inputImageBuffer = new TensorImage(inputDataType);
        // Creates the output tensor and its processor.
        probabilityImageBuffer = TensorBuffer.createFixedSize(outputImageShape, outputDataType);
        // Creates the post-processor for the output probability.
        probabilityProcessor = new TensorProcessor.Builder()
                .add(new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD))
                .build();
    }

    /**
     * Runs the inference and returns the classification results.
     *
     * @param bitmap            the bitmap of the image
     * @param sensorOrientation orientation of the camera
     * @return classification results
     */
    public List<Recognition> recognizeImage(final Bitmap bitmap, final int sensorOrientation) {
        // List with the label and probability of each class
        List<Recognition> recognitions = new ArrayList<>();
        inputImageBuffer = loadImage(bitmap, sensorOrientation);
        tensorClassifier.run(inputImageBuffer.getBuffer(), probabilityImageBuffer.getBuffer().rewind());
        // Gets the map of label and probability.
        Map<String, Float> labelledProbability = new TensorLabel(labels,
                probabilityProcessor.process(probabilityImageBuffer)).getMapWithFloatValue();
        int idLabel = 0;
        for (Map.Entry<String, Float> entry : labelledProbability.entrySet()) {
            recognitions.add(new Recognition(String.valueOf(idLabel), entry.getValue()));
            idLabel++;
        }
        // List with the probability of each class
        List<Float> probabilidades = new ArrayList<>();
        for (Map.Entry<String, Float> entry : labelledProbability.entrySet()) {
            probabilidades.add(entry.getValue());
        }
        Collections.sort(recognitions);
        return recognitions.subList(0, MAX_SIZE);
    }

    /**
     * Loads the image into the input tensor buffer and applies the preprocessing steps.
     *
     * @param bitmap            the bitmap to be loaded
     * @param sensorOrientation the sensor orientation
     * @return the input tensor buffer with the image loaded
     */
    private TensorImage loadImage(Bitmap bitmap, int sensorOrientation) {
        // Loads bitmap into a TensorImage.
        inputImageBuffer.load(bitmap);
        int noOfRotations = sensorOrientation / 90;
        int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight());
        // Preprocessing steps are applied here
        ImageProcessor imageProcessor = new ImageProcessor.Builder()
                .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
                .add(new ResizeOp(imageResizeX, imageResizeY, ResizeOp.ResizeMethod.BILINEAR))
                .add(new Rot90Op(noOfRotations))
                .add(new NormalizeOp(IMAGE_MEAN, IMAGE_STD))
                .build();
        return imageProcessor.process(inputImageBuffer);
    }

    /** An immutable result returned by a Classifier describing what was recognized. */
    public class Recognition implements Comparable<Recognition> {
        /** Display name for the recognition. */
        private String name;
        /** A sortable score for how good the recognition is relative to others. Higher should be better. */
        private float confidence;

        public Recognition() {
        }

        public Recognition(String name, float confidence) {
            this.name = name;
            this.confidence = confidence;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public float getConfidence() {
            return confidence;
        }

        public void setConfidence(float confidence) {
            this.confidence = confidence;
        }

        @Override
        public String toString() {
            return "Recognition{" +
                    "name='" + name + '\'' +
                    ", confidence=" + confidence +
                    '}';
        }

        // Sorts in descending order of confidence.
        @Override
        public int compareTo(Recognition o) {
            return Float.compare(o.confidence, this.confidence);
        }
    }
}
```
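To compare the Android result with the Python interpreter's `output_data`, it may help to reproduce the post-processing of `recognizeImage` on the Python side. A minimal sketch, assuming `probabilities` is the model's output row in label order and that the label map iterates in that same order: `PROBABILITY_MEAN = 0` and `PROBABILITY_STD = 1` make the output `NormalizeOp` an identity, each entry is named by its index (as `idLabel` does), and `compareTo` sorts by descending confidence before the top `MAX_SIZE` entries are kept. The function `top_recognitions` and the sample values are hypothetical:

```python
# Mirrors the post-processing in ImageClassifier.recognizeImage (sketch).
PROBABILITY_MEAN = 0.0
PROBABILITY_STD = 1.0
MAX_SIZE = 3


def top_recognitions(probabilities, max_size=MAX_SIZE):
    """Return (name, confidence) pairs sorted by descending confidence."""
    # NormalizeOp(0, 1) on the output leaves every value unchanged.
    scores = [(p - PROBABILITY_MEAN) / PROBABILITY_STD for p in probabilities]
    # Like the Java loop, each result is named after its index ("0", "1", ...).
    recognitions = [(str(i), s) for i, s in enumerate(scores)]
    # Recognition.compareTo sorts by confidence in descending order (stable sort).
    recognitions.sort(key=lambda r: r[1], reverse=True)
    return recognitions[:max_size]


print(top_recognitions([0.1, 0.7, 0.2]))  # [('1', 0.7), ('2', 0.2), ('0', 0.1)]
```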