1

搜索论坛后,我没有找到类似的问题。如果你找到了,请告诉我。我真的很感激。

我需要从截断的伽马分布中生成 1000 个样本点均值,在 R 中具有 1000 个不同的形状和比例值。

我的以下代码有效,但速度很慢。如何提高性能?

library(distr)
library(distrEx)
library(truncdist)
set.seed(RANDOM.SEED)
shape.list <- runif(1000, max = 10, min = 0.01)
scale.list <- runif(1000, max = 100000, min = 100000)
mean.list <- list()
std.dev.list <- list()
for (i in seq(1000)) # very slow
{
  sample.points <- rtrunc(100000, spec="gamma", a = lb.arg, b = ub.arg, 
                         shape = shape.list[[i]], scale = scale.list[[i]])
  sample.mean <- mean(sample.points)
  mean.list <- append(mean.list, sample.mean)
  sample.std.dev <- sd(sample.points)
  std.dev.list <- append(std.dev.list, sample.std.dev)
}

for 循环非常慢并且需要很长时间。

任何更好的解决方案将不胜感激。谢谢 !

4

2 回答 2

2

这里发生了几件事。

首先,您将比例计算为:

scale.list <- runif(1000, max = 100000, min = 100000)

但由于min = max,所有值都是相同的。

其次,您没有指定lb.argor ub.arg,所以我将它们设置为 20 和 50 任意。

第三,使用分析此代码Rprof表明 > 90% 的时间花在由 .qtrunc(...)调用的函数中rtrunc(...)。这是因为您在每次迭代中生成 100,000 个样本点,并且qtrunc(...)必须对它们进行排序。总运行时间为 O(n),其中 n 是采样点的数量。在我的系统上,使用 n=1000 大约需要 7 秒,因此使用 n=100,000 需要 700 秒或大约 12 分钟。

我的建议是尝试更小的 n ,看看这是否真的有所作为。从中心极限定理我们知道,对于大的 n,均值的分布是渐近正态的,而与基础分布无关。我怀疑将 n 从 1000 增加到 100,000 会显着改变。

最后,在 R 中执行此操作的惯用方法是 [使用 n=1000]:

f <- function(shape,scale,lb,ub){
  X <- rtrunc(1000, spec="gamma", a=lb, b=ub,shape=shape,scale=scale)
  return(c(mean=mean(X),sd=sd(X)))
}
# set.seed(1)  # use this for a reproducible sample
result <- mapply(f,shape.list,scale.list, lb=20,ub=50)
result.apply <- data.frame(t(result))

它产生一个包含两列的数据框:每个形状/比例的平均值和标准差。通过在运行之前将种子设置为固定值mapply(...),并在运行 for 循环之前执行相同的操作,您可以证明它们都产生相同的结果。

于 2014-04-19T16:18:30.327 回答
2

不幸的是,没有办法优化您的任务。从截断分布生成随机点肯定有一些可能的优化......但事情是这样的:从随机分布生成 10^8 点左右会非常慢。

以下是我尝试的一些优化,它们加快了进程:

  1. 一次从 [a,b] 中的均匀分布生成所有随机数

  2. 回到截断分布定义的源头,而不依赖于“花哨”包(distr、distEx、truncdist)

  3. 编译我的代码以加快速度

代码:

# your original code, in a function
func = function()
{
  library(distr)
  library(distrEx)
  library(truncdist)
  set.seed(42)
  shape.list <- runif(1000, max = 10, min = 0.01)
  scale.list <- runif(1000, max = 100000, min = 100000)
  mean.list <- list()
  std.dev.list <- list()
  ITE.NUMBER = 10
  POINTS.NUMBER = 100000
  A = 0.25
  B = 0.5
  for (i in seq(ITE.NUMBER)) # very slow
  {
    sample.points <- rtrunc(POINTS.NUMBER, spec="gamma", a = A, b = B, 
                            shape = shape.list[[i]], scale = scale.list[[i]])
    sample.mean <- mean(sample.points)
    mean.list <- append(mean.list, sample.mean)
    sample.std.dev <- sd(sample.points)
    std.dev.list <- append(std.dev.list, sample.std.dev)
  }
}


# custom code
func2 = function()
{
  set.seed(42)
  shape.list <- runif(1000, max = 10, min = 0.01)
  scale.list <- runif(1000, max = 100000, min = 100000)
  mean.list <- list()
  std.dev.list <- list()
  ITE.NUMBER = 10
  POINTS.NUMBER = 100000
  A=0.25
  B=0.5
  # 
  # we generate all the random number at once, outside the loop
  #
  r <- runif(POINTS.NUMBER*ITE.NUMBER, min = 0, max = 1)
  for (i in seq(ITE.NUMBER)) # still very slow
  {
    #
    # back to the definition of the truncated gamma
    #
    sample.points <- qgamma(pgamma(A,  shape = shape.list[[i]], scale = scale.list[[i]]) +
                            r[(1+POINTS.NUMBER*(ITE.NUMBER-1)):(POINTS.NUMBER*(ITE.NUMBER))] *
                              (pgamma(B,  shape = shape.list[[i]], scale = scale.list[[i]]) - 
                                 pgamma(A,  shape = shape.list[[i]], scale = scale.list[[i]])),
                            shape = shape.list[[i]], scale = scale.list[[i]])
    sample.mean <- mean(sample.points)
    mean.list <- append(mean.list, sample.mean)
    sample.std.dev <- sd(sample.points)
    std.dev.list <- append(std.dev.list, sample.std.dev)
  }
}

#
# maybe a compilation would help?   
#  
require(compiler)
func2_compiled <- cmpfun(func2)

require(microbenchmark)

microbenchmark(func2(), func2_compiled(), func(),  times=10)

这给出了以下内容:

Unit: seconds
             expr      min       lq   median       uq      max neval
          func2() 1.462768 1.465561 1.475692 1.489235 1.532693    10
 func2_compiled() 1.403956 1.477983 1.487945 1.499133 1.515504    10
           func() 1.457553 1.478829 1.502671 1.510276 1.513486    10

结论:

  1. 如前所述,几乎没有改进的余地:您的任务非常需要资源,并且没有什么可做的。

  2. 编译几乎让事情变得更糟......这是意料之中的:这里没有愚蠢地使用糟糕的编程技术(例如大丑陋的循环)

  3. 如果您真的在寻求速度改进,那么使用另一种语言可能会更好,尽管我怀疑您是否能够获得明显更好的性能..

于 2014-04-19T16:23:33.730 回答