
I couldn't find an equivalent of the handy which(x, arr.ind=T) function in Rcpp or RcppArmadillo, so I decided to quickly write one myself.

// [[Rcpp::export]]
arma::umat whicha(arma::mat matrix, int what ){
  arma::uvec outp1;
  int n  =   matrix.n_rows;
  outp1  =   find(matrix==what);
  int nf =   outp1.n_elem;
  arma::mat  out(nf,2);
  arma::vec  foo;
  arma::uvec foo2;
  foo = arma::conv_to<arma::colvec>::from(outp1) +1;  
  foo2 = arma::conv_to<arma::uvec>::from(foo);
  for(int i=0; i<nf; i++){
    out(i,0) = ( foo2(i) %n);
    out(i,1) =  ceil(foo(i) / n ); 
    if(out(i,0)==0) {
      out(i,0)=n;
    }
  }
  return(arma::conv_to<arma::umat>::from(out));
}

The code looks rather inefficient, but microbenchmark shows that it can be faster than R's which function.

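Question: can I change this function further so that it truly reproduces R's which function, i.e. so that I can pass MATRIX == something to it? Right now I need a second argument. I would like this purely for convenience.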


Update: fixed a bug - ceil is needed instead of floor
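The floor/ceil issue comes from converting 1-based linear indices. Since arma::find returns 0-based indices into column-major storage, the conversion can also be done directly on those, which avoids both the rounding call and the special case for a remainder of zero. A minimal sketch in plain C++ (no Rcpp or Armadillo); the helper name linear_to_rc is hypothetical and only used for illustration:

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical helper, not part of Rcpp or Armadillo.
// Maps a 0-based column-major linear index to a 1-based (row, column) pair,
// i.e. the convention used by which(x, arr.ind = TRUE) in R.
std::pair<std::size_t, std::size_t> linear_to_rc(std::size_t idx, std::size_t n_rows) {
    assert(n_rows > 0);           // guards against division by zero
    return { idx % n_rows + 1,    // the row index cycles fastest in column-major storage
             idx / n_rows + 1 };  // integer division counts the completed columns
}
```

With n_rows = 10, the last element of the first column (0-based index 9) maps to row 10, column 1, which is exactly the case the floor-based version got wrong.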

How to check:

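# 10 x 10 test matrix of non-negative integers (whicha expects a matrix, not a plain vector)
ma <- matrix(floor(abs(rnorm(100, 0, 6))), nrow = 10)
testf <- function(k) { all(which(ma == k, arr.ind = TRUE) == whicha(ma, k)) }; sapply(1:10, testf)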

Benchmark:

> microbenchmark(which(ma==k,arr.ind=T) , whicha(ma,k))
Unit: microseconds
                        expr    min     lq median     uq    max neval
 which(ma == k, arr.ind = T) 10.264 11.170 11.774 12.377 51.317   100
               whicha(ma, k)  3.623  4.227  4.830  5.133 36.224   100

2 Answers


Here is my code, which uses only Rcpp:

src <- '
    using namespace std;

    NumericMatrix X(X_);
    double what = as<double>(what_);
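    int n_rows = X.nrow();
    int n_cols = X.ncol();

    NumericVector rows(0);
    NumericVector cols(0);

    // X is stored column-major, so scan all n_rows * n_cols entries
    for(int ii = 0; ii < n_rows * n_cols; ii++)
    {
        if(X[ii] == what)
        {
            rows.push_back(ii % n_rows + 1);   // 1-based row
            cols.push_back(ii / n_rows + 1);   // 1-based column (integer division)
        }
    }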

    return List::create(rows, cols);
'

fun <- inline::cxxfunction(signature(X_ = 'numeric', what_ = 'numeric'), src, 'Rcpp')

X <- matrix(1:1E4, nrow=1E2)

rbenchmark::benchmark(fun(X, 100), which(X == 100L, TRUE), columns = c('test', 'replications', 'elapsed', 'relative'), replications = 1000)

                   test replications elapsed relative
1           fun(X, 100)         1000   0.077    1.000
2 which(X == 100, TRUE)         1000   0.100    1.299

microbenchmark::microbenchmark(fun(X, 100), which(X == 100L, TRUE), times = 1000L)

                   expr    min      lq  median      uq      max neval
            fun(X, 100) 37.372 41.3780 43.6530 48.4825 1650.154  1000
 which(X == 100L, TRUE) 63.366 64.0745 64.3345 64.8240 1911.858  1000

This is no slower than the previous poster's solution. Interestingly, returning a data frame instead of a list degrades performance noticeably.
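For reference, the same scan can be sketched in plain C++ without Rcpp. This is only an illustration under stated assumptions: the matrix is passed as a column-major std::vector<double> together with its row and column counts (so non-square matrices are covered), and the names Matches and which_equal are hypothetical, not part of any library.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// 1-based row and column positions of the matches, kept as two parallel
// vectors, similar in spirit to the two-element list returned above.
struct Matches {
    std::vector<std::size_t> rows;
    std::vector<std::size_t> cols;
};

// Hypothetical helper: scan a column-major n_rows x n_cols matrix for
// entries that compare exactly equal to `what`.
Matches which_equal(const std::vector<double>& x, std::size_t n_rows,
                    std::size_t n_cols, double what) {
    assert(x.size() == n_rows * n_cols);  // dimensions must match the storage
    Matches m;
    for (std::size_t i = 0; i < x.size(); ++i) {
        if (x[i] == what) {
            m.rows.push_back(i % n_rows + 1);  // row within the current column
            m.cols.push_back(i / n_rows + 1);  // number of completed columns, plus one
        }
    }
    return m;
}
```

Keeping rows and columns as two flat vectors mirrors the list return value; assembling a richer structure on the way back to R is where extra cost may appear.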

answered 2013-08-28T22:08:23.947

I would do this by writing a wrapper R function that does some ugly work to take the call apart. An example, using your code:

whicha.cpp
----------

#include <RcppArmadillo.h>
// [[Rcpp::depends("RcppArmadillo")]]

// [[Rcpp::export]]
arma::umat whicha(arma::mat matrix, int what ){
  arma::uvec outp1;
  int n =   matrix.n_rows;
  outp1 =   find(matrix==what);
  int nf = outp1.n_elem;
  arma::mat out(nf,2);
  arma::vec foo;
  arma::uvec foo2;

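  foo = arma::conv_to<arma::vec>::from(outp1) +1;   // 1-based linear indices
  out.col(1) = ceil(  foo  / n );                   // column: ceil, not floor + 1
  foo2 = arma::conv_to<arma::uvec >::from(foo);
  for(int i=0; i<nf; i++){
    out(i,0) =  foo2(i) % n;
    if(out(i,0)==0) out(i,0) = n;                   // remainder 0 means the last row
  }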

  return(arma::conv_to<arma::umat >::from(out));
}

/*** R
whichRcpp <- function(x) {
  call <- match.call()$x
  xx <- eval.parent( call[[2]] )
  what <- eval.parent( call[[3]] )
  return( whicha(xx, what) )
}
x <- matrix(1:1E4, nrow=1E2)
identical( whichRcpp(x == 100L), whicha(x, 100L) ) ## TRUE
microbenchmark::microbenchmark( whichRcpp(x == 100L), whicha(x, 100L) )
*/

Unfortunately, microbenchmark tells me that parsing the call is somewhat slower:

Unit: microseconds
                 expr    min     lq median      uq    max neval
 whichRcpp(x == 100L) 43.542 44.143 44.443 45.0440 73.271   100
      whicha(x, 100L) 30.029 30.630 30.930 31.2305 78.075   100

Parsing the call at the C level might be worth your while, but I'll leave that to you.
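An alternative to parsing the call at all is to accept the already-evaluated comparison, i.e. the logical matrix that x == 100 produces, and only report where it is true. The following is a rough sketch of that idea in plain C++, not a drop-in Rcpp function: the name which_mask is hypothetical, the mask is assumed to arrive as a column-major std::vector<bool>, and in an Rcpp version the argument would presumably be an Rcpp::LogicalMatrix instead.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: `mask` holds the result of a comparison such as
// x == 100, stored column-major; n_rows is the number of rows of x.
// Returns the 1-based (row, column) pair of every true entry.
std::vector<std::pair<std::size_t, std::size_t>>
which_mask(const std::vector<bool>& mask, std::size_t n_rows) {
    assert(n_rows > 0 && mask.size() % n_rows == 0);  // whole columns only
    std::vector<std::pair<std::size_t, std::size_t>> out;
    for (std::size_t i = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            out.emplace_back(i % n_rows + 1, i / n_rows + 1);
        }
    }
    return out;
}
```

The trade-off is that R then evaluates x == 100 itself and allocates the full logical matrix before the C++ code runs, so part of the speed advantage over which may disappear.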

answered 2013-08-26T05:56:56.570