我正在尝试编写一个程序来识别目录中的哪些图像与查询图像相似,查询图像与目录中的图像相似但通常略有不同。目录中有数千张图片。这个问题与比较图像相似度的简单快速方法有关。
我有几个目标:
- 使用查询图像,识别图像目录中的相似图像
- 查询图像可能与目录中的图像略有不同。这些更改可能包括被裁剪的图像和不同的图像质量。
- 该程序应该非常快(最多能够在几秒钟内识别出相似的图像)
我知道这是一个需要大量研究的问题。“云、移动和边缘的实用深度学习”中的“构建反向图像搜索引擎:理解嵌入”一章解释了这个问题的一些方法。
我开始使用 SIFT(尺度不变特征变换)+词袋方法编写一个程序来做到这一点。我在这方面没有太多经验。我编写的程序适用于相同的图像,也适用于稍微相似的图像,但是一旦图像变得更加不同,它就不再检测到正确的图像。
我有两个问题:
- 我使用的方法是否很好,如果没有,是否有更好的方法?
- 我的程序中是否有任何内容可能导致搜索对不同图像不准确?
这是程序的工作方式:
- 遍历每个图像,使用 SIFT 获取其描述符,并构建这些描述符的列表。
- 使用 k-means,找到描述符列表的质心。这就是“字典”。
- 再次遍历每个图像,并为每个图像的描述符和质心获取 k 近邻 knnMatch,其中 k=1。使用每个匹配项为每个图像创建一个直方图,使用 match.trainIdx。
- 通过将每个“单词”的计数除以“单词”的总和来标准化每个图像的直方图。
- 使用 knnMatch 和 k=1 与查询图像的描述符和质心。浏览匹配项并创建标准化直方图。
- 对查询图像的直方图以及数据库中所有图像的直方图使用 knnMatch,k=1。这将创建一个匹配列表,按与查询图像的相似性排序。
import numpy as np
import cv2
import os
from matplotlib import pyplot as plt
sift = cv2.xfeatures2d.SIFT_create()
FLANN_INDEX_KDTREE = 0
index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 100)
search_params = dict(checks = 100)
flann = cv2.FlannBasedMatcher(index_params, search_params)
bf = cv2.BFMatcher()
img1 = cv2.imread('path',0)
db = # load database
kp1, des1 = sift.detectAndCompute(img1,None)
load = False
clusters = 800
if load:
db.query('DELETE FROM centroids')
db.query('DELETE FROM histogram')
descriptors = []
for file in os.listdir('path'):
if file.endswith('.png'):
img = cv2.imread('path/{}'.format(file), 0)
kp, des = sift.detectAndCompute(img,None)
if des is None:
continue
descriptors.extend(des)
descriptors = np.float32(descriptors)
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 5, .01)
centroids = cv2.kmeans(descriptors, clusters, None, criteria, 1, cv2.KMEANS_PP_CENTERS)[2]
db.insert('centroids', d = np.ndarray.dumps(centroids))
for file in os.listdir('path'):
counter = np.zeros((clusters,), dtype=np.uint32)
if file.endswith('.png'):
img = cv2.imread('path/{}'.format(file),0)
kp, d = sift.detectAndCompute(img,None)
if d is None:
continue
matches = bf.knnMatch(d, centroids, k=1)
for match in matches:
counter[match[0].trainIdx] += 1
counter_sum = np.sum(counter)
counter = [float(n)/counter_sum for n in counter]
db.insert('histogram', frame_id = file, count=','.join(np.char.mod('%f', counter)))
histograms_db = list(db.query('SELECT * FROM histogram'))
histograms = []
for histogram in histograms_db:
histogram = histogram['count'].split(',')
histograms.append(histogram)
histograms = np.array(histograms)
counter = np.zeros((clusters,), dtype=np.uint32)
centroids = np.loads(db.query('SELECT * FROM centroids')[0]['d'])
matches = bf.knnMatch(des1, centroids, k=1)
for match in matches:
counter[match[0].trainIdx] += 1
counter_sum = np.sum(counter)
counter = [float(n)/counter_sum for n in counter]
matches = bf.knnMatch(np.float32([counter]), np.float32(histograms), k=1)
for match in matches[0]:
print "{} {}".format(histograms_db[match.trainIdx]['frame_id'], match.distance)
name = histograms_db[match.trainIdx]['frame_id']