3

解决方案现已在Rcpp Gallery中上线


我从 RcppArmadillo 中的 mvtnorm 包中重新实现了 dmvnorm。我有点喜欢犰狳,但我想它也可以在普通的 Rcpp 中工作。dmvnorm 的方法基于马氏距离,所以我有一个函数,然后是多元正态密度函数。

让我向您展示我的代码:

#include <RcppArmadillo.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppArmadillo")]]

// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x ,  arma::mat mu, arma::mat sigma ){

  int n = x.n_rows;
  arma::vec md(n);
    for (int i=0; i<n; i++){
        arma::mat x_i = x.row(i) - mu;
        arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
        md(i) = arma::as_scalar(x_i * Y);
    }
    return md;

    }



// [[Rcpp::export]]
arma::vec dmvnorm ( arma::mat x,  arma::mat mean,  arma::mat sigma, bool log){ 

arma::vec distval = mahalanobis_arma(x,  mean, sigma);

    double logdet = sum(arma::log(arma::eig_sym(sigma)));
    double log2pi = 1.8378770664093454835606594728112352797227949472755668;
    arma::vec logretval = -( (x.n_cols * log2pi + logdet + distval)/2  ) ;

       if(log){ 
         return(logretval);

       }else { 
       return(exp(logretval));
         }
}

所以,并没有让我非常失望:

模拟一些数据

sigma <- matrix(c(4,2,2,3), ncol=2)
x <- rmvnorm(n=5000000, mean=c(1,2), sigma=sigma, method="chol")

和基准

system.time(mvtnorm::dmvnorm(x,t(1:2),.2+diag(2),F))
   user  system elapsed 
   0.05    0.02    0.06 

system.time(dmvnorm(x,t(1:2),.2+diag(2),F))
   user  system elapsed 
   0.12    0.02    0.14 

不!!!!!!:-(

[编辑]

问题是:1)为什么 RcppArmadillo 实现比普通 R 实现慢?2) 如何创建一个优于 R 实现的 Rcpp/RcppArmadillo 实现?

[编辑 2]

我将 mahalanobis_arma 放入 mvtnorm::dmvnorm 函数中,它也变慢了。

4

1 回答 1

8

如果您想要更快地实现马氏距离,您只需重新编写算法并模仿 R 使用的算法。这非常简单

我稍微修改了您的功能mahalanobis_arma以转换murowvec.

基本上我只是将 R 代码翻译成 RcppArmadillo

mahalanobis
function (x, center, cov, inverted = FALSE, ...) 
{
    x <- if (is.vector(x)) 
        matrix(x, ncol = length(x))
    else as.matrix(x)
    x <- sweep(x, 2, center)
    if (!inverted) 
        cov <- solve(cov, ...)
    setNames(rowSums((x %*% cov) * x), rownames(x))
}
<bytecode: 0x6e5b408>
<environment: namespace:stats>

这里是

#include <RcppArmadillo.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::export]]
arma::vec Mahalanobis(arma::mat x, arma::rowvec center, arma::mat cov){
    int n = x.n_rows;
    arma::mat x_cen;
    x_cen.copy_size(x);
    for (int i=0; i < n; i++) {
        x_cen.row(i) = x.row(i) - center;
    }
    return sum((x_cen * cov.i()) % x_cen, 1);    
}


// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x ,  arma::rowvec mu, arma::mat sigma ){

  int n = x.n_rows;
  arma::vec md(n);
    for (int i=0; i<n; i++){
        arma::mat x_i = x.row(i) - mu;
        arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
        md(i) = arma::as_scalar(x_i * Y);
    }
    return md;

    }

现在,让我们比较这个新的犰狳版本 ( Mahalanobis)、您的第一个版本 ( mahalanobis_arma) 和 R 实现 ( mahalanobis)。

我将此 Cpp 代码保存为mahalanobis.cpp

require(RcppArmadillo)
sourceCpp("mahalanobis.cpp")

set.seed(1)
x <- matrix(rnorm(10000 * 10), ncol = 10)
Sx <- cov(x)


all.equal(c(Mahalanobis(x, colMeans(x), Sx))
          ,mahalanobis(x, colMeans(x), Sx))
## [1] TRUE

all.equal(mahalanobis_arma(x, colMeans(x), Sx)
          ,Mahalanobis(x, colMeans(x), Sx))
## [1] TRUE


require(rbenchmark)
benchmark(Mahalanobis(x, colMeans(x), Sx),
          mahalanobis(x, colMeans(x), Sx),
          mahalanobis_arma(x, colMeans(x), Sx),
          order = "elapsed")


##                                   test replications elapsed
## 1      Mahalanobis(x, colMeans(x), Sx)          100   0.124
## 2      mahalanobis(x, colMeans(x), Sx)          100   0.741
## 3 mahalanobis_arma(x, colMeans(x), Sx)          100   4.509
##   relative user.self sys.self user.child sys.child
## 1    1.000     0.173    0.077          0         0
## 2    5.976     0.804    0.670          0         0
## 3   36.363     4.386    4.626          0         0

如您所见,新的实现比 R 更快。我很确定我们可以通过使用 cholesky 分解来求解协方差矩阵或使用其他矩阵分解来做得更好。

最后,我们可以将这个Mahalanobis函数插入你dmvnorm并测试它:

require(mvtnorm)
set.seed(1)
sigma <- matrix(c(4, 2, 2, 3), ncol = 2)
x <- rmvnorm(n = 5000000, mean = c(1, 2), sigma = sigma, method = "chol")


all.equal(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          c(dmvnorm(x, t(1:2), .2+diag(2), FALSE)))
## [1] TRUE

benchmark(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          dmvnorm(x, t(1:2), .2+diag(2), FALSE),
          order = "elapsed")

##                                                test replications
## 2          dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
## 1 mvtnorm::dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
##   elapsed relative user.self sys.self user.child sys.child
## 2  35.366    1.000    31.117    4.193          0         0
## 1  60.770    1.718    56.666   13.236          0         0

它现在几乎快了两倍。

于 2013-07-12T20:17:40.400 回答