2

[这也在multidplyr github页面上报道]

我正在尝试将 multidplyr_0.0.0.9000 与 purrr_0.2.4.9000 中的 dplyr_0.7.4.9000 和 pmap_dfr 一起使用。以下代码(不使用 multidplyr)工作正常:

grid1 = as_tibble(expand.grid(m1 = c(1:10), m2 = c(20:30)))
retstuff = function(m1, m2) { return(tribble(~m3, ~m4, m1+1, m2+2)) }
pmap_dfr(grid1, retstuff)

当我尝试使用 multidplyr 对网格进行分区时:

grid2 = partition(grid1, m1)
pmap_dfr(grid2, retstuff)

Error: Element 5 is not a vector (environment)从 pmap_dfr()得到错误

我还从 partition() 收到以下警告,正如 github 上所报告的那样group_indices_.grouped_df ignores extra arguments。不确定这是否相关。

4

1 回答 1

3

几个问题:

  • 您需要在每个节点上加载任何必要的包(除了 dplyr),
  • 您需要将函数复制到每个节点,并且
  • 您只能在分区数据帧上调用 dplyr 动词,因此您需要将pmap_dfr调用包装在dplyr::do

之后它的工作原理:

library(tidyverse)
library(multidplyr)

grid1 <- as_tibble(expand.grid(m1 = c(1:10), m2 = c(20:30)))
retstuff <- function(m1, m2) { 
    tribble(   ~m3,    ~m4, 
            m1 + 1, m2 + 2)
}

grid2 <- partition(grid1, m1)
#> Initialising 7 core cluster.
#> Warning: group_indices_.grouped_df ignores extra arguments
cluster_library(grid2, 'tidyverse')    # load packages on each node
cluster_copy(grid2, retstuff)    # copy function to each node

grid2 %>% do(pmap_dfr(., retstuff))    # wrap call in dplyr::do
#> Source: party_df [110 x 3]
#> Groups: m1
#> Shards: 7 [11--22 rows]
#> 
#> # S3: party_df
#>       m1    m3    m4
#>    <int> <dbl> <dbl>
#>  1     9    10    22
#>  2     9    10    23
#>  3     9    10    24
#>  4     9    10    25
#>  5     9    10    26
#>  6     9    10    27
#>  7     9    10    28
#>  8     9    10    29
#>  9     9    10    30
#> 10     9    10    31
#> # ... with 100 more rows

...但是对于这种特殊情况,虽然 multidplyr 快一点,但 plaindplyr::mutate快得多,而且更容易编写:

grid1 %>% mutate(m3 = m1 + 1, m4 = m2 + 2)
#> # A tibble: 110 x 4
#>       m1    m2    m3    m4
#>    <int> <int> <dbl> <dbl>
#>  1     1    20     2    22
#>  2     2    20     3    22
#>  3     3    20     4    22
#>  4     4    20     5    22
#>  5     5    20     6    22
#>  6     6    20     7    22
#>  7     7    20     8    22
#>  8     8    20     9    22
#>  9     9    20    10    22
#> 10    10    20    11    22
#> # ... with 100 more rows

all.equal(grid2 %>% do(pmap_dfr(., retstuff)) %>% collect, 
          grid1 %>% mutate(m3 = m1 + 1, m4 = m2 + 2) %>% select(-m2))
#> [1] TRUE

microbenchmark::microbenchmark(
    multidplyr_pmap = grid2 %>% do(pmap_dfr(., retstuff)) %>% collect(), 
    multidplyr_mutate = grid2 %>% mutate(m3 = m1 + 1, m4 = m2 + 2) %>% collect(),
    pmap = grid1 %>% pmap_dfr(retstuff),
    mutate = grid1 %>% mutate(m3 = m1 + 1, m4 = m2 + 2) %>% select(-m2)
)
#> Unit: milliseconds
#>               expr        min        lq       mean    median         uq       max neval
#>    multidplyr_pmap 113.896646 117.18365 122.656286 119.75652 125.874450 182.53330   100
#>  multidplyr_mutate  12.419918  12.84528  16.271337  13.68441  15.092482 177.77372   100
#>               pmap 372.512544 387.49371 397.844622 394.71971 402.640281 551.78633   100
#>             mutate   7.014426   7.49689   8.499588   7.66554   8.654478  32.22647   100  
于 2017-11-02T01:26:22.313 回答