4

我有一个数字向量 v (已经省略了 NA)并想要获得第 n 个最大值及其各自的频率。

我发现 http://gallery.rcpp.org/articles/top-elements-from-vectors-using-priority-queue/ 非常快。

// [[Rcpp::export]]
std::vector<int> top_i_pq(NumericVector v, unsigned int n)
{

typedef pair<double, int> Elt;
priority_queue< Elt, vector<Elt>, greater<Elt> > pq;
vector<int> result;

for (int i = 0; i != v.size(); ++i) {
    if (pq.size() < n)
      pq.push(Elt(v[i], i));
    else {
      Elt elt = Elt(v[i], i);
      if (pq.top() < elt) {
        pq.pop();
        pq.push(elt);
      }
    }
  }

  result.reserve(pq.size());
  while (!pq.empty()) {
    result.push_back(pq.top().second + 1);
    pq.pop();
  }

  return result ;

}

然而,关系将不被尊重。实际上我不需要索引,返回值也可以。

我想得到的是一个包含值和频率的列表,比如:

numv <- c(4.2, 4.2, 4.5, 0.1, 4.4, 2.0, 0.9, 4.4, 3.3, 2.4, 0.1)

top_i_pq(numv, 3)
$lengths
[1] 2 2 1

$values
[1] 4.2 4.4 4.5

获取唯一向量、表格或(完整)排序都不是一个好主意,因为与 v 的长度(可能很容易 >1e6)相比,n 通常很小。

到目前为止的解决方案是:

 library(microbenchmark)
 library(data.table)
 library(DescTools)

 set.seed(1789)
 x <- sample(round(rnorm(1000), 3), 1e5, replace = TRUE)
 n <- 5

 microbenchmark(
   BaseR = tail(table(x), n),
   data.table = data.table(x)[, .N, keyby = x][(.N - n + 1):.N],
   DescTools = Large(x, n, unique=TRUE),
   Coatless = ...
 )

Unit: milliseconds
       expr       min         lq       mean     median        uq       max neval
      BaseR 188.09662 190.830975 193.189422 192.306297 194.02815 253.72304   100
 data.table  11.23986  11.553478  12.294456  11.768114  12.25475  15.68544   100
  DescTools   4.01374   4.174854   5.796414   4.410935   6.70704  64.79134   100

嗯,DescTools 仍然是最快的,但我确信 Rcpp 可以显着改进它(因为它是纯 R)!

4

3 回答 3

5

我想用另一个基于 Rcpp 的解决方案来挑战我的能力,使用上面的 1e5 长度和示例数据,它比方法快约 7 倍,比DescTools方法快约 13 倍。实现有点冗长,所以我将带头进行基准测试:data.tablexn = 5

fn.dt <- function(v, n) {
    data.table(v = v)[
      ,.N, keyby = v
      ][(.N - n + 1):.N]
}

microbenchmark(
    "DescTools" = Large(x, n, unique=TRUE),
    "top_n" = top_n(x, 5),
    "data.table" = fn.dt(x, n),
    times = 500L
)
# Unit: microseconds
#        expr      min       lq      mean   median       uq       max neval
#   DescTools 3330.527 3790.035 4832.7819 4070.573 5323.155 54921.615   500
#       top_n  566.207  587.590  633.3096  593.577  640.832  3568.299   500
#  data.table 6920.636 7380.786 8072.2733 7764.601 8585.472 14443.401   500

更新

如果您的编译器支持 C++11,您可以利用std::priority_queue::emplace(令人惊讶的)显着性能提升(与下面的 C++98 版本相比)。我不会发布这个版本,因为它大部分是相同的,除了几次调用std::moveand emplace,但这里有一个链接

针对前三个函数进行测试,并使用data.table1.9.7(比 1.9.6 快一点)产生

print(res2, order = "median", signif = 3)
# Unit: relative
#              expr  min    lq      mean median    uq   max neval  cld
#            top_n2  1.0  1.00  1.000000   1.00  1.00  1.00  1000    a   
#             top_n  1.6  1.58  1.666523   1.58  1.75  2.75  1000    b  
#         DescTools 10.4 10.10  8.512887   9.68  7.19 12.30  1000    c 
#  data.table-1.9.7 16.9 16.80 14.164139  15.50 10.50 43.70  1000    d 

top_n2C ++ 11版本 在哪里。


top_n函数实现如下:

#include <Rcpp.h>
#include <utility>
#include <queue>

class histogram {
private:
    struct paired {
        typedef std::pair<double, unsigned int> pair_t;

        pair_t pair;
        unsigned int is_set;

        paired() 
            : pair(pair_t()),
              is_set(0)
        {}

        paired(double x)
            : pair(std::make_pair(x, 1)),
              is_set(1)
        {}

        bool operator==(const paired& other) const {
            return pair.first == other.pair.first;
        }

        bool operator==(double other) const {
            return is_set && (pair.first == other);
        }

        bool operator>(double other) const {
            return is_set && (pair.first > other);
        }

        bool operator<(double other) const {
            return is_set && (pair.first < other);
        }

        paired& operator++() {
            ++pair.second;
            return *this;
        }

        paired operator++(int) {
            paired tmp(*this);
            ++(*this);
            return tmp;
        }
    };

    struct greater {
        bool operator()(const paired& lhs, const paired& rhs) const {
            if (!lhs.is_set) return false;
            if (!rhs.is_set) return true;
            return lhs.pair.first > rhs.pair.first;
        }
    };  

    typedef std::priority_queue<
        paired,
        std::vector<paired>,
        greater
    > queue_t;

    unsigned int sz;
    queue_t queue;

    void insert(double x) {
        if (queue.empty()) {
            queue.push(paired(x));
            return;
        }

        if (queue.top() > x && queue.size() >= sz) return;

        queue_t qtmp;
        bool matched = false;

        while (queue.size()) {
            paired elem = queue.top();
            if (elem == x) {
                qtmp.push(++elem);
                matched = true;
            } else {
                qtmp.push(elem);
            }
            queue.pop();
        }

        if (!matched) {
            if (qtmp.size() >= sz) qtmp.pop();
            qtmp.push(paired(x));
        }

        std::swap(queue, qtmp);
    }

public:
    histogram(unsigned int sz_) 
        : sz(sz_), 
          queue(queue_t())
    {}

    template <typename InputIt>
    void insert(InputIt first, InputIt last) {
        for ( ; first != last; ++first) {
            insert(*first);
        }
    }

    Rcpp::List get() const {
        Rcpp::NumericVector values(sz);
        Rcpp::IntegerVector freq(sz);
        R_xlen_t i = 0;

        queue_t tmp(queue);
        while (tmp.size()) {
            values[i] = tmp.top().pair.first;
            freq[i] = tmp.top().pair.second;
            ++i;
            tmp.pop();
        }

        return Rcpp::List::create(
            Rcpp::Named("value") = values,
            Rcpp::Named("frequency") = freq);
    }
};


// [[Rcpp::export]]
Rcpp::List top_n(Rcpp::NumericVector x, int n = 5) {
    histogram h(n);
    h.insert(x.begin(), x.end());
    return h.get();
} 

上面的课程有很多内容histogram,但只是触及一些关键点:

  • paired类型本质上是一个围绕 an 的包装类std::pair<double, unsigned int>,它将值与计数相关联,提供一些便利功能,例如operator++()/operator++(int)用于直接预/后递增计数,以及修改的比较运算符。
  • 该类histogram包装了一种“托管”优先级队列,从某种意义上说,它的大小std::priority_queue被限制在一个特定的值sz上。
  • 我没有使用默认std::less排序std::priority_queue,而是使用大于比较器,以便可以检查候选值std::priority_queue::top()以快速确定它们是否应该(a)被丢弃,(b)替换队列中的当前最小值,或者(c) 更新队列中现有值之一的计数。这仅是可能的,因为队列的大小被限制为 <= sz
于 2016-05-03T16:48:03.107 回答
4

我敢打赌data.table是有竞争力的:

library(data.table)

data <- data.table(v)

data[ , .N, keyby = v][(.N - n + 1):.N]

n你想得到的号码在哪里

于 2016-05-03T01:09:42.507 回答
1

注意:以前的版本复制table()了目标的功能,而不是目标。此版本已被删除,并将在异地可用。

攻击地图

以下是使用map.

C++98

首先,我们需要找到数字向量的“唯一”值。

为此,我们选择将被计为 a 的数字存储key在 a 中,并在每次观察到该数字时std::map递增。value

使用 的排序结构std::map,我们知道顶部的n数字在 的后面std::map。因此,我们使用迭代器来弹出这些元素并将它们导出到数组中。

C++11

如果一个人可以访问 C++11 编译器,另一种方法是使用std::unordered_map,它有一个O(1)用于插入和检索的大 O (O(n)如果哈希不正确)与std::map它有一个大 O 的O(log(n)).

为了获得正确的 top n,人们将使用它std::partial_sort()来这样做。

实施

C++98

#include <Rcpp.h>

// [[Rcpp::export]]
Rcpp::List top_n_map(const Rcpp::NumericVector & v, int n)
{

  // Initialize a map
  std::map<double, int> Elt;

  Elt.clear();

  // Count each element
  for (int i = 0; i != v.size(); ++i) {
    Elt[ v[i] ] += 1;
  }

  // Find out how many unique elements exist... 
  int n_obs = Elt.size();

  // If the top number, n, is greater than the number of observations,
  // then drop it.  
  if(n > n_obs ) { n = n_obs; }

  // Pop the last n elements as they are already sorted. 

  // Make an iterator to access map info
  std::map<double,int>::iterator itb = Elt.end();

  // Advance the end of the iterator up to 5.
  std::advance(itb, -n);

  // Recast for R
  Rcpp::NumericVector result_vals(n);

  Rcpp::NumericVector result_keys(n);

  unsigned int count = 0;

  // Start at the nth element and move to the last element in the map.
  for( std::map<double,int>::iterator it = itb; it != Elt.end(); ++it )
  {
    // Move them into split vectors
    result_keys(count) = it->first;
    result_vals(count) = it->second;

    count++;
  }

  return Rcpp::List::create(Rcpp::Named("lengths") = result_vals,
                            Rcpp::Named("values") = result_keys);
}

短期测试

让我们通过运行一些数据来验证它是否有效:

# Set seed for reproducibility
set.seed(1789)
x <- sample(round(rnorm(1000), 3), 1e5, replace = TRUE)
n <- 5

现在我们寻求获取发生信息:

# Call our function
top_n_map(a)

给我们:

$lengths
[1] 101 104 101 103 103

$values
[1] 2.468 2.638 2.819 3.099 3.509

基准

Unit: microseconds
       expr        min          lq        mean      median         uq        max neval
      BaseR 112750.403 115946.7175 119493.4501 117676.2840 120712.595 166067.530   100
 data.table   6583.851   6994.3665   8311.8631   7260.9385   7972.548  47482.559   100
  DescTools   3291.626   3503.5620   5047.5074   3885.4090   5057.666  43597.451   100
   Coatless   6097.237   6240.1295   6421.1313   6365.7605   6528.315   7543.271   100
nrussel_c98    513.932    540.6495    571.5362    560.0115    584.628    797.315   100
nrussel_c11    489.616    512.2810    549.6581    533.2950    553.107    961.221   100

正如我们所看到的,这个实现击败了data.table,但成为了 DescTools 和@nrussel 尝试的牺牲品。

于 2016-05-03T02:32:28.043 回答