11

Tensorflow-Lite Android 演示使用它提供的原始模型:mobilenet_quant_v1_224.tflite。见:https ://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite

他们还在这里提供了其他预训练的 lite 模型:https ://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md

但是,我从上面的链接下载了一些较小的模型,例如 mobilenet_v1_0.25_224.tflite,并在演示应用程序中将原始模型替换为此模型,只需更改MODEL_PATH = "mobilenet_v1_0.25_224.tflite";. ImageClassifier.java该应用程序崩溃:

12-11 12:52:34.222 17713-17729/? E/AndroidRuntime:致命异常:CameraBackground 进程:android.example.com.tflitecamerademo,PID:17713 java.lang.IllegalArgumentException:无法获取输入尺寸。第 0 个输入应该有 602112 字节,但找到了 150528 字节。在 org.tensorflow.lite.NativeInterpreterWrapper.getInputDims(Native Method) 在 org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:82) 在 org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:

原因似乎是模型所需的输入尺寸是图像尺寸的四倍。所以我修改DIM_BATCH_SIZE = 1DIM_BATCH_SIZE = 4. 现在错误是:

致命异常:CameraBackground 进程:android.example.com.tflitecamerademo,PID:18241 java.lang.IllegalArgumentException:无法将 FLOAT32 类型的 TensorFlowLite 张量转换为 [[B 类型的 Java 对象(与 TensorFlowLite 类型 UINT8 兼容)在 org.tensorflow.lite.Tensor.copyTo(Tensor.java:36) 在 org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:122) 在 org.tensorflow.lite.Interpreter.run(Interpreter.java:93 ) 在 com.example.android.tflitecamerademo。ImageClassifier.classifyFrame(ImageClassifier.java:108) at com.example.android.tflitecamerademo.Camera2BasicFragment.classifyFrame(Camera2BasicFragment.java:663) at com.example.android.tflitecamerademo.Camera2BasicFragment.access$900(Camera2BasicFragment.java:69) com.example.android.tflitecamerademo.Camera2BasicFragment$5.run(Camera2BasicFragment.java:558) at android.os.Handler.handleCallback(Handler.java:751) at android.os.Handler.dispatchMessage(Handler.java:95) at android.os.Looper.loop(Looper.java:154) 在 android.os.HandlerThread.run(HandlerThread.java:61)

我的问题是如何让简化的 MobileNet tflite 模型与 TF-lite Android Demo 一起使用。

(我实际上尝试了其他事情,例如使用提供的工具将 TF 冻结图转换为 TF-lite 模型,甚至使用与https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib中完全相同的示例代码/lite/toco/g3doc/cmdline_examples.md,但转换后的tflite模型在Android Demo中仍然无法使用。)

4

2 回答 2

4

Tensorflow-Lite Android 演示中包含的 ImageClassifier.java 需要一个量化模型。截至目前,只有一种 Mobilenets 模型以量化形式提供:Mobilenet 1.0 224 Quant

要使用其他浮点模型,请从 Tensorflow for Poets TF-Lite 演示源中交换 ImageClassifier.java。这是为浮点模型编写的。 https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/android/tflite/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java

做一个比较,你会发现在实现上有几个重要的区别。

另一个要考虑的选择是使用 TOCO 将浮点模型转换为量化: https ://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md

于 2017-12-12T23:28:38.077 回答
2

我也遇到了与 Seedling 相同的错误。我为 Mobilenet Float 模型创建了一个新的图像分类器包装器。现在工作正常。您可以直接在图像分类器演示中添加此类,并使用它在 Camera2BasicFragment 中创建分类器

classifier = new ImageClassifierFloatMobileNet(getActivity());

下面是 Mobilenet Float 模型的图像分类器类包装器

    /**
 * This classifier works with the Float MobileNet model.
 */
public class ImageClassifierFloatMobileNet extends ImageClassifier {

  /**
   * An array to hold inference results, to be feed into Tensorflow Lite as outputs.
   * This isn't part of the super class, because we need a primitive array here.
   */
  private float[][] labelProbArray = null;

  private static final int IMAGE_MEAN = 128;
  private static final float IMAGE_STD = 128.0f;

  /**
   * Initializes an {@code ImageClassifier}.
   *
   * @param activity
   */
  public ImageClassifierFloatMobileNet(Activity activity) throws IOException {
    super(activity);
    labelProbArray = new float[1][getNumLabels()];
  }

  @Override
  protected String getModelPath() {
    // you can download this file from
    // https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
//    return "mobilenet_quant_v1_224.tflite";
    return "retrained.tflite";
  }

  @Override
  protected String getLabelPath() {
//    return "labels_mobilenet_quant_v1_224.txt";
    return "retrained_labels.txt";
  }

  @Override
  public int getImageSizeX() {
    return 224;
  }

  @Override
  public int getImageSizeY() {
    return 224;
  }

  @Override
  protected int getNumBytesPerChannel() {
    // the Float model uses a 4 bytes
    return 4;
  }

  @Override
  protected void addPixelValue(int val) {
    imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
    imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
    imgData.putFloat((((val) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
  }

  @Override
  protected float getProbability(int labelIndex) {
    return labelProbArray[0][labelIndex];
  }

  @Override
  protected void setProbability(int labelIndex, Number value) {
    labelProbArray[0][labelIndex] = value.byteValue();
  }

  @Override
  protected float getNormalizedProbability(int labelIndex) {
    return labelProbArray[0][labelIndex];
  }

  @Override
  protected void runInference() {
    tflite.run(imgData, labelProbArray);
  }
}
于 2018-03-14T06:27:21.737 回答