0

我有以下代码,它输出的值array1小于或等于array2. 这两个数组的长度不同。这个 for 循环非常慢,因为数组很大(~500,000元素)。仅供参考,两个数组始终按升序排列。

任何帮助使它成为矢量操作并加快它的速度将不胜感激。

interp1()我正在考虑使用“最近”选项的某种多步骤过程。然后找到对应outArray的位置大于array2然后以某种方式固定点......但我认为必须有更好的方法。

array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];
outArray = nan(size(array2));
for a =1:numel(array2)
    outArray(a) = array1(find(array1 <= array2(a),1,'last'));
end

返回:

outArray =    
     5     5    15    24
4

3 回答 3

3

这是一种可能的矢量化:

[~,idx] = max(cumsum(bsxfun(@le, array1', array2)));
outArray = array1(idx);

编辑:

在最近的版本中,由于 JIT 编译,MATLAB 在执行良好的旧非向量化循环方面已经相当出色。

下面是一些与您的代码类似的代码,它利用了两个数组已排序这一事实(因此,如果pos(a) = find(array1<=array2(a), 1, 'last')我们保证pos(a+1)在下一次迭代中计算的值不会少于上一次pos(a)

pos = 1;
idx = zeros(size(array2));
for a=1:numel(array2)
    while pos <= numel(array1) && array1(pos) <= array2(a)
        pos = pos + 1;
    end
    idx(a) = pos-1;
end
%idx(idx==0) = [];      %# in case min(array2) < min(array1)
outArray = array1(idx);

注意:注释行处理最小值array2小于最小值的情况array1(即为find(array1<=array2(a))空时)

我对迄今为止发布的所有方法进行了比较,这确实是最快的一种。长度为 N=5000 的向量的时序(使用TIMEIT函数执行)为:

0.097398     # your code
0.39127      # my first vectorized code
0.00043361   # my new code above
0.0016276    # Mohsen Nosratinia's code

这里是 N=500000 的时间:

(? too-long) # your code
(out-of-mem) # my first vectorized code
0.051197     # my new code above
0.25206      # Mohsen Nosratinia's code

.. 从您报告的最初 10 分钟缩短到 0.05 秒,这是一个相当不错的改进!

如果您想重现结果,这里是测试代码:

function [t,v] = test_array_find()
    %array2 = [5 6 18 25];
    %array1 = [1 5 9 15 22 24 31];
    N = 5000;
    array1 = sort(randi([100 1e6], [1 N]));
    array2 = sort(randi([min(array1) 1e6], [1 N]));

    f = {...
        @() func1(array1,array2);   %# Aero Engy
        @() func2(array1,array2);   %# Amro
        @() func3(array1,array2);   %# Amro
        @() func4(array1,array2);   %# Mohsen Nosratinia
    };

    t = cellfun(@timeit, f);
    v = cellfun(@feval, f, 'UniformOutput',false);
    assert( isequal(v{:}) )
end

function outArray = func1(array1,array2)
    %idx = arrayfun(@(a) find(array1<=a, 1, 'last'), array2);
    idx = zeros(size(array2));
    for a=1:numel(array2)
        idx(a) = find(array1 <= array2(a), 1, 'last');
    end
    outArray = array1(idx);
end

function outArray = func2(array1,array2)
    [~,idx] = max(cumsum(bsxfun(@le, array1', array2)));
    outArray = array1(idx);
end

function outArray = func3(array1,array2)
    pos = 1;
    lastPos = numel(array1);
    idx = zeros(size(array2));
    for a=1:numel(array2)
        while pos <= lastPos && array1(pos) <= array2(a)
            pos = pos + 1;
        end
        idx(a) = pos-1;
    end
    %idx(idx==0) = [];      %# in case min(array2) < min(array1)
    outArray = array1(idx);
end

function outArray = func4(array1,array2)
    [~,I] = sort([array1 array2]);
    a1size = numel(array1);
    J = find(I>a1size);
    outArray = nan(size(array2));
    for k=1:numel(J),
        if  I(J(k)-1)<=a1size,
            outArray(k) = array1(I(J(k)-1));
        else
            outArray(k) = outArray(k-1);
        end
    end
end
于 2013-07-03T17:58:34.800 回答
2

它缓慢的一个原因是您将所有元素与 in 中array1的所有元素进行比较,array2因此如果它们分别包含MN元素,则复杂性为O(M*N). 但是,由于数组已经排序,因此有一个线性时间O(M+N), 解决方案

array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];

outArray = nan(size(array2));
k1 = 1;
n1 = numel(array1);
n2 = numel(array2);

ks = 1;
while ks <= n2 && array2(ks) < array1(1)
    ks = ks + 1;
end

for k2=ks:n2
    while k1 < n1 && array2(k2) >= array1(k1+1) 
        k1 = k1+1;
    end
    outArray(k2) = array1(k1);
end

这是一个测试用例,用于测量每种方法运行两个长度为 500,000 的数组所需的时间。

array2 = 1:500000;
array1 = array2-1;

tic
outArray1 = nan(size(array2));
k1 = 1;
n1 = numel(array1);
n2 = numel(array2);

ks = 1;
while ks <= n2 && array2(ks) < array1(1)
    ks = ks + 1;
end

for k2=ks:n2
    while k1 < n1 && array2(k2) >= array1(k1+1) 
        k1 = k1+1;
    end
    outArray1(k2) = array1(k1);
end
toc    

tic
outArray2 = nan(size(array2));
for a =1:numel(array2)
    outArray2(a) = array1(find(array1 <= array2(a),1,'last'));
end
toc

结果是

Elapsed time is 0.067637 seconds.
Elapsed time is 418.458722 seconds.
于 2013-07-03T23:45:03.017 回答
0

注意: 这是我最初的解决方案,也是在 Amro 的回答中作为基准的解决方案。但是,它比我在其他答案中提供的线性时间解决方案要慢。

它缓慢的一个原因是您将所有元素与 in 中array1的所有元素进行比较,array2因此如果它们包含MN元素,则复杂性为O(M*N). 但是,您可以将它们连接起来并将它们排序在一起,并获得更快的复杂性算法(M+N)*log2(M+N)。这是一种方法:

array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];

[~,I] = sort([array1 array2]);
a1size = numel(array1);
J = find(I>a1size);
outArray = nan(size(array2));
for k=1:numel(J),
    if  I(J(k)-1)<=a1size,
        outArray(k) = array1(I(J(k)-1));
    else
        outArray(k) = outArray(k-1);
    end
end

disp(outArray)

% Test using original code
outArray = nan(size(array2));
for a =1:numel(array2)
    outArray(a) = array1(find(array1 <= array2(a),1,'last'));
end
disp(outArray)

连接的数组将是

>> [array1 array2]
ans =
     1     5     9    15    22    24    31     5     6    18    25

>> [B,I] = sort([array1 array2])
B =
     1     5     5     6     9    15    18    22    24    25    31
I =
     1     2     8     9     3     4    10     5     6    11     7

它表明在排序数组B中,第一个5来自连接数组中的第二个位置,第二个 5 来自八个位置,依此类推。因此,要找到其中array1小于给定元素的最大元素,array2我们只需要遍历所有I大于array1(因此属于array2)大小的索引,然后返回并找到属于 的最近索引array1J包含这些元素在 vector 中的位置I

>> J = find(I>a1size)
J =
     3     4     7    10

现在 for 循环遍历这些索引并检查索引是否在I引用的每个索引之前的索引Jarray1。如果它属于array1它,则从中检索它的值,array1否则它复制为先前索引找到的值。

请注意,如果您的代码和此代码array2包含的元素小于array1.

于 2013-07-03T19:16:53.550 回答