
This MainActivity.java was written for a quantized model, and I am trying to use a non-quantized (float) model.

After making the changes mentioned here, my MainActivity.java is:

public class MainActivity extends AppCompatActivity implements AdapterView.OnItemSelectedListener {
    private static final String TAG = "MainActivity";
    private Button mRun;
    private ImageView mImageView;
    private Bitmap mSelectedImage;
    private GraphicOverlay mGraphicOverlay;
    // Max width (portrait mode)
    private Integer mImageMaxWidth;
    // Max height (portrait mode)
    private Integer mImageMaxHeight;
    private final String[] mFilePaths =
            new String[]{"mountain.jpg", "tennis.jpg","96580.jpg"};
    /**
     * Name of the model file hosted with Firebase.
     */
    private static final String HOSTED_MODEL_NAME = "mobilenet_v1_224_quant";
    private static final String LOCAL_MODEL_ASSET = "retrained_graph_mobilenet_1_224.tflite";
    /**
     * Name of the label file stored in Assets.
     */
    private static final String LABEL_PATH = "labels.txt";
    /**
     * Number of results to show in the UI.
     */
    private static final int RESULTS_TO_SHOW = 3;
    /**
     * Dimensions of inputs.
     */
    private static final int DIM_BATCH_SIZE = 1;
    private static final int DIM_PIXEL_SIZE = 3;
    private static final int DIM_IMG_SIZE_X = 224;
    private static final int DIM_IMG_SIZE_Y = 224;
    private static final int IMAGE_MEAN = 128;
    private static final float IMAGE_STD = 128.0f;
    /**
     * Labels corresponding to the output of the vision model.
     */
    private List<String> mLabelList;

    private final PriorityQueue<Map.Entry<String, Float>> sortedLabels =
            new PriorityQueue<>(
                    RESULTS_TO_SHOW,
                    new Comparator<Map.Entry<String, Float>>() {
                        @Override
                        public int compare(Map.Entry<String, Float> o1,
                                Map.Entry<String, Float> o2) {
                            return o1.getValue().compareTo(o2.getValue());
                        }
                    });
    /* Preallocated buffers for storing image data. */
    private final int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
    /**
     * An instance of the driver class to run model inference with Firebase.
     */
    private FirebaseModelInterpreter mInterpreter;
    /**
     * Data configuration of input & output data of model.
     */
    private FirebaseModelInputOutputOptions mDataOptions;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        mGraphicOverlay = findViewById(R.id.graphic_overlay);
        mImageView = findViewById(R.id.image_view);

        Spinner dropdown = findViewById(R.id.spinner);
        List<String> items = new ArrayList<>();
        for (int i = 0; i < mFilePaths.length; i++) {
            items.add("Image " + (i + 1));
        }

        ArrayAdapter<String> adapter = new ArrayAdapter<>(
                this, android.R.layout.simple_spinner_dropdown_item, items);
        dropdown.setAdapter(adapter);
        dropdown.setOnItemSelectedListener(this);

        mLabelList = loadLabelList(this);
        mRun = findViewById(R.id.button_run);
        mRun.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                runModelInference();
            }
        });

        int[] inputDims = {DIM_BATCH_SIZE, DIM_IMG_SIZE_X, DIM_IMG_SIZE_Y, DIM_PIXEL_SIZE};
        int[] outputDims = {DIM_BATCH_SIZE, mLabelList.size()};
        try {
            mDataOptions =
                    new FirebaseModelInputOutputOptions.Builder()
                            .setInputFormat(0, FirebaseModelDataType.FLOAT32, inputDims)
                            .setOutputFormat(0, FirebaseModelDataType.FLOAT32, outputDims)
                            .build();
            FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions
                    .Builder()
                    .requireWifi()
                    .build();
            FirebaseLocalModelSource localModelSource =
                    new FirebaseLocalModelSource.Builder("asset")
                            .setAssetFilePath(LOCAL_MODEL_ASSET).build();

            FirebaseCloudModelSource cloudSource =
                    new FirebaseCloudModelSource.Builder(HOSTED_MODEL_NAME)
                            .enableModelUpdates(true)
                            .setInitialDownloadConditions(conditions)
                            // You could also specify different conditions for updates.
                            .setUpdatesDownloadConditions(conditions)
                            .build();
            FirebaseModelManager manager = FirebaseModelManager.getInstance();
            manager.registerLocalModelSource(localModelSource);
            manager.registerCloudModelSource(cloudSource);
            FirebaseModelOptions modelOptions =
                    new FirebaseModelOptions.Builder()
                            .setCloudModelName(HOSTED_MODEL_NAME)
                            .setLocalModelName("asset")
                            .build();
            mInterpreter = FirebaseModelInterpreter.getInstance(modelOptions);
        } catch (FirebaseMLException e) {
            showToast("Error while setting up the model");
            e.printStackTrace();
        }
    }

    private void runModelInference() {
        if (mInterpreter == null) {
            Log.e(TAG, "Image classifier has not been initialized; Skipped.");
            return;
        }
        // Create input data.
        ByteBuffer imgData = convertBitmapToByteBuffer(mSelectedImage, mSelectedImage.getWidth(),
                mSelectedImage.getHeight());

        try {
            FirebaseModelInputs inputs = new FirebaseModelInputs.Builder().add(imgData).build();
            // Here's where the magic happens!!
            mInterpreter
                    .run(inputs, mDataOptions)
                    .addOnFailureListener(new OnFailureListener() {
                        @Override
                        public void onFailure(@NonNull Exception e) {
                            e.printStackTrace();
                            showToast("Error running model inference");
                        }
                    })
                    .continueWith(
                            new Continuation<FirebaseModelOutputs, List<String>>() {
                                @Override
                                public List<String> then(Task<FirebaseModelOutputs> task) {
                                    float[][] labelProbArray = task.getResult()
                                            .<float[][]>getOutput(0);
                                    List<String> topLabels = getTopLabels(labelProbArray);
                                    mGraphicOverlay.clear();
                                    GraphicOverlay.Graphic labelGraphic =
                                            new LabelGraphic(mGraphicOverlay, topLabels);
                                    mGraphicOverlay.add(labelGraphic);
                                    return topLabels;
                                }
                            });
        } catch (FirebaseMLException e) {
            e.printStackTrace();
            showToast("Error running model inference");
        }

    }

    /**
     * Gets the top labels in the results.
     */
    private synchronized List<String> getTopLabels(float[][] labelProbArray) {
        for (int i = 0; i < mLabelList.size(); ++i) {
            sortedLabels.add(
                    new AbstractMap.SimpleEntry<>(mLabelList.get(i), (labelProbArray[0][i] )));
            if (sortedLabels.size() > RESULTS_TO_SHOW) {
                sortedLabels.poll();
            }
        }
        List<String> result = new ArrayList<>();
        final int size = sortedLabels.size();
        for (int i = 0; i < size; ++i) {
            Map.Entry<String, Float> label = sortedLabels.poll();
            result.add(label.getKey() + ":" + label.getValue());
        }
        Log.d(TAG, "labels: " + result.toString());
        return result;
    }

    /**
     * Reads label list from Assets.
     */
    private List<String> loadLabelList(Activity activity) {
        List<String> labelList = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(activity.getAssets().open(LABEL_PATH)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                labelList.add(line);
            }
        } catch (IOException e) {
            Log.e(TAG, "Failed to read label list.", e);
        }
        return labelList;
    }

    /**
     * Writes Image data into a {@code ByteBuffer}.
     */
    private synchronized ByteBuffer convertBitmapToByteBuffer(
            Bitmap bitmap, int width, int height) {
        // 4 bytes per channel value, since each r/g/b value is now stored as a float.
        ByteBuffer imgData =
                ByteBuffer.allocateDirect(
                        4 * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);

        imgData.order(ByteOrder.nativeOrder());
        Bitmap scaledBitmap = Bitmap.createScaledBitmap(bitmap, DIM_IMG_SIZE_X, DIM_IMG_SIZE_Y,
                true);
        imgData.rewind();
        scaledBitmap.getPixels(intValues, 0, scaledBitmap.getWidth(), 0, 0,
                scaledBitmap.getWidth(), scaledBitmap.getHeight());
        // Convert the image to floating point: one normalized float per channel.
        int pixel = 0;
        for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
            for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
                final int val = intValues[pixel++];
                imgData.putFloat((((val >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                imgData.putFloat((((val >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                imgData.putFloat(((val & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
            }
        }
        return imgData;
    }

    private void showToast(String message) {
        Toast.makeText(getApplicationContext(), message, Toast.LENGTH_SHORT).show();
    }

    @Override
    public void onItemSelected(AdapterView<?> parent, View v, int position, long id) {
        mGraphicOverlay.clear();
        mSelectedImage = getBitmapFromAsset(this, mFilePaths[position]);
        if (mSelectedImage != null) {
            // Get the dimensions of the View
            Pair<Integer, Integer> targetedSize = getTargetedWidthHeight();

            int targetWidth = targetedSize.first;
            int maxHeight = targetedSize.second;

            // Determine how much to scale down the image
            float scaleFactor =
                    Math.max(
                            (float) mSelectedImage.getWidth() / (float) targetWidth,
                            (float) mSelectedImage.getHeight() / (float) maxHeight);

            Bitmap resizedBitmap =
                    Bitmap.createScaledBitmap(
                            mSelectedImage,
                            (int) (mSelectedImage.getWidth() / scaleFactor),
                            (int) (mSelectedImage.getHeight() / scaleFactor),
                            true);

            mImageView.setImageBitmap(resizedBitmap);
            mSelectedImage = resizedBitmap;
        }
    }

    @Override
    public void onNothingSelected(AdapterView<?> parent) {
        // Do nothing
    }

    // Utility functions for loading and resizing images from app asset folder.
    public static Bitmap getBitmapFromAsset(Context context, String filePath) {
        AssetManager assetManager = context.getAssets();
        Bitmap bitmap = null;
        // try-with-resources ensures the asset stream is closed.
        try (InputStream is = assetManager.open(filePath)) {
            bitmap = BitmapFactory.decodeStream(is);
        } catch (IOException e) {
            e.printStackTrace();
        }

        return bitmap;
    }

    // Returns max image width, always for portrait mode. Caller needs to swap width / height for
    // landscape mode.
    private Integer getImageMaxWidth() {
        if (mImageMaxWidth == null) {
            // Calculate the max width in portrait mode. This is done lazily because we
            // need a UI layout pass to get the right value, so we defer it until the
            // first image is rendered.
            mImageMaxWidth = mImageView.getWidth();
        }

        return mImageMaxWidth;
    }

    // Returns max image height, always for portrait mode. Caller needs to swap width / height for
    // landscape mode.
    private Integer getImageMaxHeight() {
        if (mImageMaxHeight == null) {
            // Calculate the max height in portrait mode. This is done lazily because we
            // need a UI layout pass to get the right value, so we defer it until the
            // first image is rendered.
            mImageMaxHeight = mImageView.getHeight();
        }

        return mImageMaxHeight;
    }

    // Gets the targeted width / height.
    private Pair<Integer, Integer> getTargetedWidthHeight() {
        int targetWidth;
        int targetHeight;
        int maxWidthForPortraitMode = getImageMaxWidth();
        int maxHeightForPortraitMode = getImageMaxHeight();
        targetWidth = maxWidthForPortraitMode;
        targetHeight = maxHeightForPortraitMode;
        return new Pair<>(targetWidth, targetHeight);
    }
}

But I still get "Failed to get input dimensions. 0-th input should have 268203 bytes, but found 1072812 bytes" for Inception and "0-th input should have 150528 bytes, but found 602112 bytes" for MobileNet. So there is always a factor of 4.

To see the changes I made, the output of diff original.java changed.java is (ignore the line numbers):

32a33,34
>     private static final int IMAGE_MEAN = 128;
>     private static final float IMAGE_STD = 128.0f;
150,151c152,153
<                                     byte[][] labelProbArray = task.getResult()
<                                             .<byte[][]>getOutput(0);
---
>                                     float[][] labelProbArray = task.getResult()
>                                             .<float[][]>getOutput(0);
170c172
<     private synchronized List<String> getTopLabels(byte[][] labelProbArray) {
---
>     private synchronized List<String> getTopLabels(float[][] labelProbArray) {
173,174c175
<                     new AbstractMap.SimpleEntry<>(mLabelList.get(i), (labelProbArray[0][i] &
<                             0xff) / 255.0f));
---
>                     new AbstractMap.SimpleEntry<>(mLabelList.get(i), (labelProbArray[0][i] )));
214c215,216
<                         DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
---
>                        4*DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
> 
226,228c228,232
<                 imgData.put((byte) ((val >> 16) & 0xFF));
<                 imgData.put((byte) ((val >> 8) & 0xFF));
<                 imgData.put((byte) (val & 0xFF));
---
>                 imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
>                 imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
>                 imgData.putFloat(((val & 0xFF)-IMAGE_MEAN)/IMAGE_STD);

1 Answer


This is how the buffer is allocated in the codelab:

ByteBuffer imgData = ByteBuffer.allocateDirect(
    DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);

DIM_BATCH_SIZE: a typical use is to support batch processing, if the model supports it. In our sample and in your test you feed one image at a time, so just keep it at 1.

DIM_PIXEL_SIZE: we set this to 3 in the codelab, corresponding to one byte each for r/g/b.

However, it looks like you are using a float model. Then, instead of one byte each for r/g/b, one float (4 bytes) is used for each of r/g/b (you already figured that part out yourself). The buffer allocated by the code above is therefore no longer big enough.
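
For a quick sanity check, using the 224x224 MobileNet input from your own code, the byte counts work out exactly to the numbers in your error messages:

// One byte per r/g/b value: what a quantized model expects.
int quantizedBytes = 1 * 224 * 224 * 3;     // = 150528
// One 4-byte float per r/g/b value: what your buffer now contains.
int floatBytes = 4 * 1 * 224 * 224 * 3;     // = 602112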

You can follow the float model example here: https://github.com/tensorflow/tensorflow/blob/25b4086bb5ba1788ceb6032eda58348f6e20a71d/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java

To be exact, the allocation formula for imgData is the following:

ByteBuffer imgData = ByteBuffer.allocateDirect(
    DIM_BATCH_SIZE * getImageSizeX() * getImageSizeY() * DIM_PIXEL_SIZE 
    * getNumBytesPerChannel());

getNumBytesPerChannel() should be 4 in your case.
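
For reference, the float classifier in that demo differs from the quantized one mainly in a handful of overrides. The following is a paraphrased sketch, not a verbatim copy of the linked file; it assumes imgData, IMAGE_MEAN (128) and IMAGE_STD (128.0f) are fields of the classifier, as in the demo:

// Paraphrased sketch of an ImageClassifierFloatInception-style subclass;
// constructor and fields omitted for brevity, see the linked file for the real code.
public class ImageClassifierFloatInception extends ImageClassifier {

    @Override
    protected int getImageSizeX() { return 299; }  // Inception takes 299x299 input

    @Override
    protected int getImageSizeY() { return 299; }

    @Override
    protected int getNumBytesPerChannel() { return 4; }  // one 4-byte float per channel

    @Override
    protected void addPixelValue(int pixelValue) {
        // Write normalized floats instead of raw bytes.
        imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
        imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
        imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
    }
}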


[Update for the new question about the following error]:

Failed to get input dimensions. 0-th input should have 268203 bytes, but found 1072812 bytes

This check verifies that the number of bytes the model expects equals the number of bytes passed in. 268203 = 299 * 299 * 3 and 1072812 = 4 * 299 * 299 * 3. It looks like you are using a quantized model but feeding it data prepared for a float model. Could you double-check which model you are using? For simplicity, don't specify a cloud model source and use only the local model from assets, as sketched below.
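
A local-only setup would look something like this sketch, reusing the identifiers from your own code ("asset" and LOCAL_MODEL_ASSET), still inside your existing try/catch for FirebaseMLException:

// Register only the local model from assets; no cloud model source at all.
FirebaseLocalModelSource localModelSource =
        new FirebaseLocalModelSource.Builder("asset")
                .setAssetFilePath(LOCAL_MODEL_ASSET)
                .build();
FirebaseModelManager.getInstance().registerLocalModelSource(localModelSource);

FirebaseModelOptions modelOptions =
        new FirebaseModelOptions.Builder()
                .setLocalModelName("asset")  // note: no setCloudModelName(...)
                .build();
mInterpreter = FirebaseModelInterpreter.getInstance(modelOptions);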


[Update 06/28, the developer said they trained a float model]:

It could be that your model is wrong; it could also be that you downloaded a cloud model that overrode the local one. But the error message tells us that the model being loaded is not a float model.

To isolate the issue, I'd suggest the following tests: 1) remove setCloudModelName / registerCloudModelSource from the quickstart app; 2) use the official TFLite float model: download the model mentioned in the comments and change Camera2BasicFragment to use ImageClassifierFloatInception (instead of ImageClassifierQuantizedMobileNet); 3) still using the same TFLite sample app, switch to your own trained model, making sure the image size is adjusted to your values.

Answered on 2018-06-26T17:54:17.120