
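So, I'm trying to build a Dog vs. Cat image classification model with Keras. Part of my goal is to build a website that serves the model with TensorFlow.js. I have already deployed the model successfully, using Flask as the server.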

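The main problem is that the model performs much worse in TensorFlow.js than in plain Keras. In plain Keras, my model reaches about 90% accuracy on the test data. In TensorFlow.js, however, it does not classify a single test image correctly. I would appreciate any help or tips on fixing this.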
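One way to narrow down where the two runtimes diverge is to run the same preprocessed images through both and compare the raw sigmoid scores, not just the final accuracy. The helper below is a hypothetical sketch (compareScores is not part of the original code); it assumes you have collected the Keras scores (from model.predict in Python) and the TF.js scores (for example the values shown by alert(score)) as two plain arrays in the same image order.

```javascript
// Hypothetical debugging helper: compare sigmoid scores produced by Keras
// and by TF.js for the SAME images, listed in the same order.
function compareScores(kerasScores, tfjsScores, threshold = 0.5) {
  if (kerasScores.length !== tfjsScores.length) {
    throw new Error("Score lists must have the same length");
  }
  let maxAbsDiff = 0;
  let agreements = 0;
  for (let i = 0; i < kerasScores.length; i++) {
    const k = kerasScores[i];
    const t = tfjsScores[i];
    // Largest numeric gap between the two runtimes on any single image.
    maxAbsDiff = Math.max(maxAbsDiff, Math.abs(k - t));
    // Count images where both runtimes land on the same side of the threshold.
    if ((k >= threshold) === (t >= threshold)) agreements++;
  }
  return { maxAbsDiff, agreementRate: agreements / kerasScores.length };
}
```

If maxAbsDiff is large even though both runtimes received identical inputs, the problem lies in the runtime or backend rather than in preprocessing.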

templates/index.html

<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width">
    <title>repl.it</title>

    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>

    <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
    <link href="{{ url_for('static', filename='index.css') }}" rel="stylesheet" type="text/css" />
    <link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
  </head>
  <body onload="$('#result').hide();$('#continue').hide();">
    <div class="container-fluid">
      <!-- START HEADER -->
      <div class="row" id="headerRow">
        <div class="col-md d-flex align-items-center" id="headerColumn">
          <h2>Cat<span class='or'>or</span>Dog</h2>
        </div>
      </div>
      <!-- END HEADER -->

      <!-- START BODY -->
      <div class="row bodyRow" id='bodyRow'>
        <div class="col-md d-flex align-items-center bodyColumn">
          <div class="body">
            <form class="d-flex align-items-center  justify-content-center imageSubmitForm" method="POST" enctype="multipart/form-data">
              <label class="d-flex align-items-center justify-content-center" for='imageInputField'>
                <i class="material-icons">file_upload</i>

                <p id='result'></p>
                <br/>
                <p id='continue'>Press Anywhere to continue...</p>
              </label>
              <input class="imageInputField" id='imageInputField' type='file' onchange='getPrediction(url)'/>
            </form>
          </div>
        </div>
      </div>
      <!-- END BODY -->

      <!-- START RESULT -->
      <div class="row resultRow">
        <div class="col-md-6 classResultColumn">
            <div class="d-flex align-items-center justify-content-center classResultBox">
                <p id='classResult'></p>
            </div>
        </div>
        <div class="col-md-6 scoreResultColumn">
            <div class="d-flex align-items-center justify-content-center scoreResultBox">
                <p id='scoreResult'></p>
            </div>
        </div>
      </div>
      <!-- END RESULT -->

      <!-- START FOOTER -->
      <!--
      <div class="row d-flex align-items-center footerRow" id='footerRow'>
        <center><a src="#">Source Code</a></center>
      </div>
      -->
      <!-- END FOOTER -->
    </div>

    <!-- START SCRIPTS -->
    <script src="https://code.jquery.com/jquery-3.4.1.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script>
    <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script>
    <script src="{{ url_for('static', filename='index.js') }}"></script>   
    <!-- END SCRIPTS -->
  </body>
</html>

static/index.js

let fileInput = document.getElementById("imageInputField");
let classResultElement = document.getElementById("classResult");
let scoreResultElement = document.getElementById("scoreResult");
let url = "/model";

let model;
let file;
let data;
let responseContent;
let features;
let predictedClass;

let getPrediction = async(url) => {
    if (!model)
        model = await tf.loadLayersModel(url);

    file = fileInput.files[0];
    data = new FormData();
    data.append("file", file);

    $.ajax({
        url : "/api/preprocess",
        type: 'POST',
        data: data,
        traditional: true,
        processData: false,
        contentType: false,

        success: function(response)
        {
            responseContent = JSON.parse(response)['image'];

            if (responseContent != "False")
            {
                features = tf.tensor(responseContent);
                let prediction = model.predict(features);

                // dataSync() returns a typed array; the model has a single sigmoid output.
                let score = prediction.dataSync()[0];

                // Release the memory held by the input and output tensors.
                tf.dispose([features, prediction]);

                alert(score);

                if (score >= 0.5) {
                    predictedClass = "Dog";

                    classResultElement.innerHTML = "<b>Predicted Class:</b> " + predictedClass;
                    scoreResultElement.innerHTML = "<b>Certainty:</b> " + (score * 100.0).toFixed(2) + "%";
                } else {
                    predictedClass = "Cat";

                    classResultElement.innerHTML = "<b>Predicted Class:</b> " + predictedClass;
                    scoreResultElement.innerHTML = "<b>Certainty:</b> " + ((1.0 - score) * 100.0).toFixed(2) + "%";
                }

                alert(predictedClass);
            }
        }
    });
}
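The Dog/Cat branching at the end of getPrediction can be pulled out into a small pure function, which makes the threshold logic easy to test without a browser or a model. This is a sketch with a hypothetical name (describePrediction); it keeps the convention used in index.js that a score of 0.5 or higher means "Dog".

```javascript
// Sketch: turn the model's single sigmoid output into a label and a certainty.
// Assumes, as in index.js, that scores >= 0.5 mean "Dog" and < 0.5 mean "Cat".
function describePrediction(scores) {
  // model.predict(...).dataSync() returns a typed array; read its only element
  // explicitly instead of relying on array-to-number coercion.
  const score = scores[0];
  const isDog = score >= 0.5;
  return {
    label: isDog ? "Dog" : "Cat",
    // Certainty is the probability of the predicted class, as a percentage.
    certainty: (isDog ? score : 1 - score) * 100,
  };
}
```

In the page it would be called as describePrediction(model.predict(features).dataSync()), with certainty.toFixed(2) used for display.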

app.py

import flask
from flask_cors import CORS
from werkzeug.utils import secure_filename
import time
import os
import keras
import numpy as np
import json

app = flask.Flask(__name__)
CORS(app)

UPLOADS_DIR = "uploads/"
os.makedirs(UPLOADS_DIR, exist_ok=True)  # file.save() fails if the folder does not exist

@app.route("/")
def index():
  """
    Fetch and return the main homepage. 
  """
  return flask.render_template("index.html")

@app.route("/favicon.ico")
def get_favicon():
  """
    Return a fake message in order to silence the error caused by a favicon not being found.
  """
  return "Favicon Does Not Exist"

@app.route("/model")
def get_modeljson():
  """
    Get the model.json file and return its contents.
  """
  with open("model/model.json", "r") as f:
    return f.read()

@app.route("/<path:path>")
def get_shard(path):
  """
    Get the binary weight file for the model (also known as a shard).

    path    =>    the filename of the binary weight file.
  """
  return flask.send_from_directory("model/", path)

@app.route("/api/preprocess", methods=['POST'])
def preprocess():
  """
    takes an image object from an AJAX request and returns a normalized list of the values.
  """
  if flask.request.method == 'POST':
    file = flask.request.files['file']
    filename = secure_filename(file.filename)
    new_filename = "{}_{}".format(time.time(), filename)
    file.save(os.path.join(UPLOADS_DIR, new_filename))

    img_obj = keras.preprocessing.image.load_img(os.path.join(UPLOADS_DIR, new_filename), target_size=(224, 224))
    img_arr = keras.preprocessing.image.img_to_array(img_obj).reshape(1, 224, 224, 3)
    img_arr = np.divide(img_arr, 255.)

    os.remove(os.path.join(UPLOADS_DIR, new_filename))
    return json.dumps({"image":img_arr.tolist()})
  return json.dumps({"image":"False"})

if __name__ == "__main__":
  app.run()
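The /api/preprocess route returns the image as a nested list of shape (1, 224, 224, 3) scaled to [0, 1], or the string "False". Before handing that list to tf.tensor, the browser can verify that it really received that shape and range. A minimal sketch, with hypothetical helper names (nestedShape, checkPayload), assuming a rectangular (non-ragged) nested array; it only checks what the server claims to produce and cannot tell whether the training notebook used the same scaling.

```javascript
// Sketch: read the shape of a rectangular nested array by following the
// first element at each level (ragged arrays are not detected).
function nestedShape(value) {
  const shape = [];
  let current = value;
  while (Array.isArray(current)) {
    shape.push(current.length);
    current = current[0];
  }
  return shape;
}

// Sketch: validate the "image" field returned by /api/preprocess.
function checkPayload(image, expectedShape = [1, 224, 224, 3]) {
  if (!Array.isArray(image)) {
    return { ok: false, reason: "server did not return an image array" };
  }
  const shape = nestedShape(image);
  if (shape.join(",") !== expectedShape.join(",")) {
    return { ok: false, reason: `unexpected shape [${shape}]` };
  }
  // Flatten fully and confirm the values were divided by 255 into [0, 1].
  const outOfRange = image.flat(Infinity).some((v) => v < 0 || v > 1);
  return outOfRange
    ? { ok: false, reason: "values outside [0, 1]" }
    : { ok: true, reason: "" };
}
```

Calling checkPayload(responseContent) just before tf.tensor(responseContent) in getPrediction, and logging reason when ok is false, would surface a shape or scaling problem before it reaches the model.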

You can find the Kaggle notebook used to train the model here. You can find the notebook used to test the code here.

Any help or tips are greatly appreciated.


1 Answer


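After a ton of coffee and almost no sleep, I found a solution. Apparently, the WebGL backend of TensorFlow.js does not compute things exactly the way TensorFlow does in Python. The workaround is to turn off WebGL packing.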

Before loading the model (tf.loadLayersModel in the question's code), add:

tf.ENV.set("WEBGL_PACK", false);

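This does not disable WebGL itself; it turns off the packed WebGL kernels, and with that change TF.js behaves much more like Python! In current TF.js versions the documented way to set the same flag is tf.env().set('WEBGL_PACK', false).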
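The flag, and the stronger option of bypassing WebGL altogether with the CPU backend via tf.setBackend('cpu'), can be wrapped in a small setup function that runs before the model is loaded. This is a sketch, not the answerer's code: configureBackend and its useCpu option are hypothetical, and tf is passed in as a parameter so the function can be exercised with a stub. It assumes a TF.js version that provides tf.env(), tf.setBackend, tf.ready and tf.getBackend.

```javascript
// Sketch: configure the TF.js backend before calling tf.loadLayersModel().
// In the page, call it with the global: configureBackend(tf).
async function configureBackend(tf, { useCpu = false } = {}) {
  if (useCpu) {
    // Skip WebGL entirely and run on the (slower) plain-JS CPU backend.
    await tf.setBackend("cpu");
  } else {
    // Keep WebGL, but turn off packed-texture kernels.
    tf.env().set("WEBGL_PACK", false);
  }
  // Wait until the selected backend is initialised.
  await tf.ready();
  return tf.getBackend();
}
```

The CPU backend is much slower, but it is a quick diagnostic: if predictions on the CPU backend match Keras while WebGL ones do not, the WebGL backend is the culprit. Usage: configureBackend(tf).then(() => tf.loadLayersModel(url)).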

answered 2019-10-13 14:46:54