0

请原谅我,但我对 Rcpp 了解不多,但我试图弄清楚学习它是否会很好,以改进我正在编写的包。

我已经编写了一个 R 包,它(应该)使用 MCMC 算法从高维受限空间中高效且随机地均匀采样。它(未完成并且)位于https://github.com/davidkane9/kmatching

问题是当我运行称为 Gelman-Rubin 诊断的统计测试以查看我的 MCMC 算法是否收敛到平稳分布时,我应该得到 R = 1 的统计数据,但我得到的数字非常高,这基本上告诉我我的采样不好,没有人应该使用它。解决方案是获取更多样本并跳过更多样本(从每 1000 个样本中抽取 1 个样本,而不是每 100 个样本)。然而,这需要很多时间。如果你想运行一些代码,这里有一个例子:

##install the package first
data(lalonde)
matchvars = c("age", "educ", "black")
k = kmatch(x = lalonde, weight.var = "treat", match.var = matchvars, n = 1000, skiplength = 1000, chains = 2, verbose = TRUE)

看着这个的 Rprof 输出,我明白了,rnorm并且%*%大部分时间都在花费:

                       total.time total.pct self.time self.pct
"kmatch"                  1453.14    100.00      0.00     0.00
"hitandrun"               1450.18     99.79    128.80     8.86
"%*%"                      757.00     52.09    757.00    52.09
"cat"                      343.18     23.62    329.82    22.70
"rnorm"                    106.34      7.32    103.50     7.12
"mirror"                    35.26      2.43     21.84     1.50
"paste"                     14.02      0.96     14.02     0.96
"stdout"                    13.36      0.92     13.36     0.92
"runif"                     13.32      0.92     13.32     0.92
"/"                         12.82      0.88     12.82     0.88
">"                          7.42      0.51      7.42     0.51
"<"                          6.22      0.43      6.22     0.43
"-"                          5.78      0.40      5.78     0.40
"max"                        5.18      0.36      5.18     0.36
"nchar"                      5.12      0.35      5.12     0.35
"*"                          4.84      0.33      4.84     0.33
"min"                        3.94      0.27      3.94     0.27
"sum"                        3.42      0.24      3.42     0.24
"gelman.diag"                2.90      0.20      0.00     0.00
"=="                         2.86      0.20      2.86     0.20
"ncol"                       2.84      0.20      2.84     0.20
"apply"                      2.72      0.19      0.26     0.02
"+"                          2.48      0.17      2.48     0.17
"FUN"                        2.32      0.16      1.66     0.11
"^"                          2.08      0.14      2.08     0.14
":"                          1.24      0.09      1.24     0.09
"sqrt"                       0.96      0.07      0.96     0.07
"%%"                         0.90      0.06      0.90     0.06
"mean.default"               0.62      0.04      0.62     0.04
"lapply"                     0.40      0.03      0.26     0.02
"("                          0.32      0.02      0.32     0.02
"unlist"                     0.26      0.02      0.00     0.00
"array"                      0.12      0.01      0.02     0.00
"sapply"                     0.12      0.01      0.00     0.00
"matrix"                     0.06      0.00      0.02     0.00
"Null"                       0.04      0.00      0.04     0.00
"print"                      0.04      0.00      0.00     0.00
"unique"                     0.04      0.00      0.00     0.00
"abs"                        0.02      0.00      0.02     0.00
"all"                        0.02      0.00      0.02     0.00
"aperm.default"              0.02      0.00      0.02     0.00
"as.matrix.mcmc"             0.02      0.00      0.02     0.00
"file.exists"                0.02      0.00      0.02     0.00
"list"                       0.02      0.00      0.02     0.00
"print.default"              0.02      0.00      0.02     0.00
"stopifnot"                  0.02      0.00      0.02     0.00
"unique.default"             0.02      0.00      0.02     0.00
"which.min"                  0.02      0.00      0.02     0.00
"<Anonymous>"                0.02      0.00      0.00     0.00
"aperm"                      0.02      0.00      0.00     0.00
"as.mcmc.list"               0.02      0.00      0.00     0.00
"as.mcmc.list.default"       0.02      0.00      0.00     0.00
"data"                       0.02      0.00      0.00     0.00
"mcmc.list"                  0.02      0.00      0.00     0.00
"print.gelman.diag"          0.02      0.00      0.00     0.00
"quantile.default"           0.02      0.00      0.00     0.00
"sort"                       0.02      0.00      0.00     0.00
"sort.default"               0.02      0.00      0.00     0.00
"sort.int"                   0.02      0.00      0.00     0.00
"summary"                    0.02      0.00      0.00     0.00
"summary.default"            0.02      0.00      0.00     0.00

如果我设置verbose = F,cat则消失,但%*%需要大约70%的时间。我想知道尝试用 C++ 编写我的代码然后使用 RCpp 是否值得,或者是否因为花费这么多时间的函数是基本函数(已经用 C 编写)所以不值得我'将不得不忍受它或找到更好的算法。

编辑:根据 Rprof,阻碍我的一行u = Z %*% rhitandrun

## This is the loop that is being run millions of times and taking forever
for(i in 1:(n*skiplength+discard)) {
        tmin<-0;tmax<-0;
        ## runs counts how many times tried to pick a direction, if
        ## too high fail.
        runs = 0
        while(tmin ==0 && tmax ==0) {
          ## r is a random unit vector in with basis in Z
          r <- rnorm(ncol(Z))
          r <- r/sqrt(sum(r^2))

          ## u is a unit vector in the appropriate k-plane pointing in a
          ## random direction Z %*% r is the same as in mirror
          u <- Z%*%r
          c <- y/u
          ## determine intersections of x + t*u with walls
          ## the limits on how far you can go backward and forward
          ## i.e. the maximum and minimum ratio y_i/u_i for negative and positive u.
          tmin <- max(-c[u>0]); tmax <- min(-c[u<0]);
          ## unboundedness
          if(tmin == -Inf || tmax == Inf){
            stop("problem is unbounded")
          }
          ## if stuck on boundary point
          if(tmin==0 && tmax ==0) {
            runs = runs + 1
            if(runs >= 1000) stop("hitandrun found can't find feasible direction, cannot generate points")
          }
        }

        ## chose a point on the line segment
        y <- y + (tmin + (tmax - tmin)*runif(1))*u;

        ## choose a point every 'skiplength' samples
        if(i %% skiplength == 0) {
          X[,index] <- y
          index <- index + 1
        }
        if(verbose) for(j in 1:nchar(str)) cat("\b")
        str <- paste(i)
        if(verbose) cat(str)
      }

这实际上是我唯一一次在我的采样循环中进行矩阵乘法,但是我做了数千次,每个样本一次抽取一百万个样本并丢弃 99%。

4

1 回答 1

2

Rcpp 实际上已被大量用于此目的:MCMC。您通常会获得相当不错的速度增益,大约为 30 到 50 或 70。

早期的软件包之一是 Whit 的rcppbugs,他在使用他编写的一些类对其进行了编程后,将其转换为 Rcpp 以方便使用。随便在网络上搜索“Rcpp MCMC”会引导您找到一些帖子。

其他作者也为此使用了 Rcpp。它也位于 (R)Stan 的内部,因为您确实希望MCMC中固有的循环结构尽可能快地运行。因此编译。

上周我询问了 rcpp-devel 列表,我应该在明天的简短 R 用户组演示中讨论什么,“MCMC”建议或多或少占主导地位。还展示了另一个 RUG 的完整演讲。我会链接到该线程,但不知何故它落在了 Gmane 的 rcpp-devel 存档中。

所以总而言之,我会说是的,你确实想考虑在这里使用 Rcpp。

于 2013-08-07T17:30:03.363 回答