在需要从大型方阵中删除非对角线元素的应用程序中,我遇到了一个小的性能瓶颈。所以,矩阵x
17 24 1 8 15
23 5 7 14 16
4 6 13 20 22
10 12 19 21 3
11 18 25 2 9
变成
17 0 0 0 0
0 5 0 0 0
0 0 13 0 0
0 0 0 21 0
0 0 0 0 9
问题:下面的 bsxfun 和 diag 解决方案是迄今为止最快的解决方案,我怀疑我是否可以改进它同时仍将代码保留在 Matlab 中,但是有更快的方法吗?
解决方案
这是我到目前为止的想法。
通过单位矩阵执行逐元素乘法。这是最简单的解决方案:
y = x .* eye(n);
使用bsxfun
和diag
:
y = bsxfun(@times, diag(x), eye(n));
下/上三角矩阵:
y = x - tril(x, -1) - triu(x, 1);
使用循环的各种解决方案:
y = x;
for ix=1:n
for jx=1:n
if ix ~= jx
y(ix, jx) = 0;
end
end
end
和
y = x;
for ix=1:n
for jx=1:ix-1
y(ix, jx) = 0;
end
for jx=ix+1:n
y(ix, jx) = 0;
end
end
定时
bsxfun
解决方案实际上是最快的。这是我的计时码:
function timing()
clear all
n = 5000;
x = rand(n, n);
f1 = @() tf1(x, n);
f2 = @() tf2(x, n);
f3 = @() tf3(x);
f4 = @() tf4(x, n);
f5 = @() tf5(x, n);
t1 = timeit(f1);
t2 = timeit(f2);
t3 = timeit(f3);
t4 = timeit(f4);
t5 = timeit(f5);
fprintf('t1: %f s\n', t1)
fprintf('t2: %f s\n', t2)
fprintf('t3: %f s\n', t3)
fprintf('t4: %f s\n', t4)
fprintf('t5: %f s\n', t5)
end
function y = tf1(x, n)
y = x .* eye(n);
end
function y = tf2(x, n)
y = bsxfun(@times, diag(x), eye(n));
end
function y = tf3(x)
y = x - tril(x, -1) - triu(x, 1);
end
function y = tf4(x, n)
y = x;
for ix=1:n
for jx=1:n
if ix ~= jx
y(ix, jx) = 0;
end
end
end
end
function y = tf5(x, n)
y = x;
for ix=1:n
for jx=1:ix-1
y(ix, jx) = 0;
end
for jx=ix+1:n
y(ix, jx) = 0;
end
end
end
返回
t1: 0.111117 s
t2: 0.078692 s
t3: 0.219582 s
t4: 1.183389 s
t5: 1.198795 s