2

我有以下 ODE 耦合系统(来自离散积分微分 PDE): 在此处输入图像描述

xi 是我控制的 x 网格上的点。我可以用以下简单的代码解决这个问题:

using DifferentialEquations

function ode_syst(du,u,p, t)
    N = Int64(p[1])
    beta= p[2]
    deltax = 1/(N+1)
    xs = [deltax*i for i in 1:N]
    for j in 1:N
        du[j] = -xs[j]^(beta)*u[j]+deltax*sum([u[i]*xs[i]^(beta) for i in 1:N])
    end
end

N = 1000
u0 = ones(N)
beta = 2.0
p = [N, beta]
tspan = (0.0, 10^3);

prob = ODEProblem(ode_syst,u0,tspan,p);
sol = solve(prob);

然而,当我使我的网格更精细,即增加 N 时,计算时间会迅速增长(我猜缩放是 N 的二次方)。关于如何使用分布式并行或多线程来实现这一点有什么建议吗?

附加信息: 我附上了可能有助于了解程序大部分时间花费在哪里的分析图在此处输入图像描述

4

2 回答 2

4

我查看了您的代码,发现了一些问题,例如由于重新计算总和项而意外引入的 O(N^2) 行为。

我的改进版本使用 Tullio 包来进一步加快矢量化的速度。Tullio 还具有可调整的参数,如果您的系统变得足够大,这些参数将允许多线程。请参阅此处,您可以在选项部分调整哪些参数。您可能还会在那里看到 GPU 支持,我没有对此进行测试,但它可能会进一步加速或严重中断。我还选择从实际数组中获取长度,这应该使使用更经济且不易出错。

using Tullio

function ode_syst_t(du,u,p, t)
    N = length(du)
    beta = p[1]
    deltax = 1/(N+1)
    @tullio s := deltax*(u[i]*(i*deltax)^(beta))
    @tullio du[j] = -(j*deltax)^(beta)*u[j] + s
    return nothing
end

你的代码:

 @btime sol = solve(prob);
  80.592 s (1349001 allocations: 10.22 GiB)

我的代码:

prob2 = ODEProblem(ode_syst_t,u0,tspan,[2.0]);
@btime sol2 = solve(prob2);
  1.171 s (696 allocations: 18.50 MiB)

结果或多或少同意:

julia> sum(abs2, sol2(1000.0) .- sol(1000.0))
1.079046922815598e-14

我还对 Lutz Lehmanns 解决方案进行了基准测试:

prob3 = ODEProblem(ode_syst_lehm,u0,tspan,p);

@btime sol3 = solve(prob3);
  1.338 s (3348 allocations: 39.38 MiB)

然而,当我们用 (0.0, 10.0) 的 tspan 将 N 缩放到 1000000

prob2 = ODEProblem(ode_syst_t,u0,tspan,[2.0]);

@time solve(prob2);
  2.429239 seconds (280 allocations: 869.768 MiB, 13.60% gc time)

prob3 = ODEProblem(ode_syst_lehm,u0,tspan,p);

@time solve(prob3);
  5.961889 seconds (580 allocations: 1.967 GiB, 11.08% gc time)

由于在我的旧机器上使用了 2 个内核,我的代码变得快了两倍多。

于 2021-09-29T18:52:29.673 回答
2

分析公式。显然,原子术语重复了。所以它们应该只计算一次。

function ode_syst(du,u,p, t)
    N = Int64(p[1])
    beta= p[2]
    deltax = 1/(N+1)
    xs = [deltax*i for i in 1:N]
    term = [ xs[i]^(beta)*u[i] for i in 1:N]
    term_sum = deltax*sum(term)
    for j in 1:N
        du[j] = -term[j]+term_sum
    end
end

这应该只线性增加N

于 2021-09-29T18:43:00.787 回答