编辑:我列出了该问题的三个实现。
首先,可以完全消除循环,但是生成的函数avg_prec_noloop()非常消耗内存,因为它试图一次完成所有操作。只要项目的数量在 100 以内,它总是会很快工作。不幸的是,当项目数趋于 1000 或更多时,它会消耗过多的内存,并且会导致崩溃。我包括这个只是为了表明它可以在没有循环的情况下完成,但我不建议使用它。
遵循与原始类似的逻辑,但通过在项目上添加单个循环,我们有函数avg_prec_colwise. 我们可以通过一次获取整个阈值列来计算所有用户的精度和召回@K。它与之前的无循环实现有相似的时间,但它并不像内存消耗那么大,并且仍然具有当 items<=100 时它的速度相当快的特性,无论用户数量如何。对于 100,000 个用户和 10 个项目,它的速度是原来的近 300 倍;但是如果 items>=1000,它会比原来慢一百倍。每当您有大量用户和少量项目的场景时,我建议您使用它。
最后,我有一个avg_prec_rowwise可能最接近 sklearn 的实现。当项目较少时,它没有 colwise 或 noloop 函数的惊人增益,但无论项目或用户数量如何,它始终比使用原始快 10-20%。出于一般目的,我建议您使用这个。
import numpy as np
from sklearn.metrics import average_precision_score as aps
from sklearn.metrics import precision_recall_curve as prc
import warnings
warnings.filterwarnings('ignore')
def mean_aps(true_scores, predicted_scores):
'''Mean Average Precision Score'''
return np.mean([aps(t, p) for t, p in zip(true_scores, predicted_scores) if t.sum() > 0])
def avg_prec_noloop(yt, yp):
valid = yt.sum(axis=1) != 0
yt, yp = yt[valid], yp[valid]
THRESH = np.sort(yp).T
yp = yp.reshape(1, yp.shape[0], yp.shape[1]) >= THRESH.reshape(THRESH.shape[0], THRESH.shape[1], 1)
a = (yt*(yt==yp)).sum(axis=2)
b = yp.sum(axis=2)
c = yt.sum(axis=1)
p = (np.where(b==0,0,a/b))
r = a/c
rdif = np.vstack((r[:-1]-r[1:],r[-1]))
return (rdif*p).sum()/yt.shape[0]
def avg_prec_colwise(yt, yp):
valid = yt.sum(axis=1) != 0
yt, yp = yt[valid], yp[valid]
N_USER, N_ITEM = yt.shape
THRESH = np.sort(yp)
p, r = np.zeros((N_USER, N_ITEM)), np.zeros((N_USER, N_ITEM))
c = yt.sum(axis=1)
for i in range(N_ITEM):
ypt = yp >= THRESH[:,i].reshape(-1,1)
a = (yt*(yt==ypt)).sum(axis=1)
b = ypt.sum(axis=1)
p[:,i] = np.where(b==0,0,a/b).reshape(-1)
r[:,i] = a/c
rdif = np.hstack((r[:,:-1]-r[:,1:],r[:,-1].reshape(-1,1)))
return (rdif*p).sum()/N_USER
def avg_prec_rowwise(yt, yp):
valid = yt.sum(axis=1) != 0
yt, yp = yt[valid], yp[valid]
N_USER, N_ITEM = yt.shape
p, r = np.zeros((N_USER, N_ITEM)), np.zeros((N_USER, N_ITEM))
for i in range(N_USER):
a, b, _ = prc(yt[i,:], yp[i,:])
p[i,:len(a)-1] = a[:-1]
r[i,:len(b)-1] = b[:-1]
rdif = np.hstack((r[:,:-1]-r[:,1:],r[:,-1].reshape(-1,1)))
return (rdif*p).sum()/N_USER
一些时间场景: 1)真正的项目少
N_USERS = 10000
N_ITEMS = 10
a = np.random.choice(2,(N_USERS, N_ITEMS))
b = np.random.random(size=(N_USERS, N_ITEMS))
start = time.time()
for i in range(10):
mean_aps(a,b)
end = time.time()
print('Original:',end-start)
start = time.time()
for i in range(10):
avg_prec_colwise(a,b)
end = time.time()
print('Colwise:',end-start)
start = time.time()
for i in range(10):
avg_prec_rowwise(a,b)
end = time.time()
print('Rowwise:',end-start)
出去:
Original: 47.91176509857178
Colwise: 0.16370844841003418
Rowwise: 37.96852993965149
2)更多项目:
N_USERS = 3000
N_ITEMS = 100
a = np.random.choice(2,(N_USERS, N_ITEMS))
b = np.random.random(size=(N_USERS, N_ITEMS))
start = time.time()
for i in range(10):
mean_aps(a,b)
end = time.time()
print('Original:',end-start)
start = time.time()
for i in range(10):
avg_prec_colwise(a,b)
end = time.time()
print('Colwise:',end-start)
start = time.time()
for i in range(10):
avg_prec_rowwise(a,b)
end = time.time()
print('Rowwise:',end-start)
出去:
Original: 14.943019151687622
Colwise: 2.0997579097747803
Rowwise: 11.798128604888916
3)物品数量:
N_USERS = 3000
N_ITEMS = 1000
a = np.random.choice(2,(N_USERS, N_ITEMS))
b = np.random.random(size=(N_USERS, N_ITEMS))
start = time.time()
for i in range(10):
mean_aps(a,b)
end = time.time()
print('Original:',end-start)
start = time.time()
for i in range(10):
avg_prec_colwise(a,b)
end = time.time()
print('Colwise:',end-start)
start = time.time()
for i in range(10):
avg_prec_rowwise(a,b)
end = time.time()
print('Rowwise:',end-start)
出去:
Original: 20.760642051696777
Colwise: 248.5634708404541
Rowwise: 17.940539121627808
4)原始和逐行之间的最后比较,没有任何循环:
N_USERS = 10000
N_ITEMS = 1000
a = np.random.choice(2,(N_USERS, N_ITEMS))
b = np.random.random(size=(N_USERS, N_ITEMS))
start = time.time()
mean_aps(a,b)
end = time.time()
print('Original:',end-start)
start = time.time()
avg_prec_rowwise(a,b)
end = time.time()
print('Rowwise:',end-start)
出去:
Original: 6.912739515304565
Rowwise: 5.9845476150512695