尝试使用 yolov4 创建 webapp,已将 yolov4 权重转换为 tensorflow 权重,当我运行app.py文件时,它显示错误AttributeError: 'InteractiveSession' object has no attribute 'tiny'。
使用 2 个文件app.py和app_helper.py。
app_helper.py
import tensorflow as tf
import cv2
from PIL import Image
import numpy as np
from core.yolov4 import YOLOv4,YOLOv4_tiny,decode,decode_tf,filter_boxes
from core.utils import load_weights,read_class_names,image_preprocess,draw_bbox,load_config
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
def get_images(image_path,image_name):
class_path='./data/classes/coco.names'
weights_path='./checkpoints/yolov4-416'
NUM_CLASS=80
tiny=False
size=416
output='/data/detections'
framework='tf'
image='./data/images/kite.jpg'
model='yolov4'
iou=0.45
score=0.25
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
STRIDES, ANCHORS, NUM_CLASS, XYSCALE = load_config(session) #utils.
input_size = size
image_path = image
original_image = cv2.imread(image_path)
original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
# image_data = utils.image_preprocess(np.copy(original_image), [input_size, input_size])
image_data = cv2.resize(original_image, (input_size, input_size))
image_data = image_data / 255.
# image_data = image_data[np.newaxis, ...].astype(np.float32)
images_data = []
for i in range(1):
images_data.append(image_data)
images_data = np.asarray(images_data).astype(np.float32)
if framework == 'tflite':
interpreter = tf.lite.Interpreter(model_path=weights_path) #replaced FLAGS.weights
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
interpreter.set_tensor(input_details[0]['index'], images_data)
interpreter.invoke()
pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
if model == 'yolov3' and tiny == True:
boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25,
input_shape=tf.constant([input_size, input_size]))
else:
boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25,
input_shape=tf.constant([input_size, input_size]))
else:
saved_model_loaded = tf.saved_model.load(weights_path, tags=[tag_constants.SERVING]) #weights #
infer = saved_model_loaded.signatures['serving_default']
batch_data = tf.constant(images_data)
pred_bbox = infer(batch_data)
for key, value in pred_bbox.items():
boxes = value[:, :, 0:4]
pred_conf = value[:, :, 4:]
boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
scores=tf.reshape(
pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
max_output_size_per_class=50,
max_total_size=50,
iou_threshold=iou,
score_threshold=score
)
pred_bbox = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
image = draw_bbox(original_image, pred_bbox) #utils.
# image = utils.draw_bbox(image_data*255, pred_bbox)
image = Image.fromarray(image.astype(np.uint8))
image.show()
image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
cv2.imwrite(output, image)
和app.py
from flask import Flask, render_template, Response, request, session, redirect, url_for, send_from_directory,flash,jsonify
from werkzeug.utils import secure_filename
from PIL import Image
import os
import sys
import cv2
from app_helper import *
from flask_cors import CORS,cross_origin
app=Flask(__name__)
upload_folder='./data/images'
app.config['upload_folder'] = upload_folder
@app.route("/")
def index():
return render_template("index.html")
@app.route("/about")
def about():
return render_template("about.html")
@app.route('/uploader', methods=['GET', 'POST'])
def upload_file():
if request.method == 'POST':
f = request.files['file']
# create a secure filename
filename = secure_filename(f.filename)
print(filename)
# save file to data/images #/static/uploads
filepath = os.path.join(app.config['upload_folder'], filename)
print(filepath)
f.save(filepath)
get_images(filepath,filename)
return render_template("uploaded.html", display_detection=filepath, fname=filepath)
if __name__ == '__main__':
app.run(port=7000, debug=True)
请帮助我了解如何克服此错误