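I am trying to find the coordinates where two curves intersect in R. The input data are the coordinates of empirical points from the two curves. My solution is to use the function curve_intersect(). I need to do this for 2000 replications (i.e., 2000 pairs of curves), so I put the data in two lists. Each list contains 2000 data frames, and each data frame holds the x and y coordinates of one curve.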

Here is my data: data

Below is the code I used.

threshold_or1 <- map2_df(recall_or1_4, precision_or1_4,
                         ~curve_intersect(.x, .y, empirical = TRUE, domain = NULL))

# recall_or1_4 is a list of 2000 data frames. Each data frame
# contains the coordinates of curve #1.

# precision_or1_4 is a list of 2000 data frames. Each data frame
# contains the coordinates of curve #2.

I got the error message below.

Error in uniroot(function(x) curve1_f(x) - curve2_f(x), c(min(curve1$x),  : f() values at end points not of opposite sign

Since curve_intersect() works on some individual pairs of data frames from the two lists, I ran the following code to see exactly which pair causes the process to fail.

test <- for (i in 1:2000){
            curve_intersect(recall_or1_4[[i]], precision_or1_4[[i]], empirical = TRUE, domain = NULL)
            print(paste("i=",i))}

I then got the following output, which means the process ran successfully until it reached data pair #460. So I checked that individual pair.

[1] "i= 457"
[1] "i= 458"
[1] "i= 459"
Error in uniroot(function(x) curve1_f(x) - curve2_f(x), c(min(curve1$x),  : f() values at end points not of opposite sign

I plotted data pair #460.

test1 <- precision_or1_4[[460]] %>% mutate(statistics = 'precision')
test2 <- recall_or1_4[[460]] %>% mutate(statistics = 'recall')
test3 <- rbind(test1, test2)
test3 <- test3 %>% mutate(statistics = as.factor(statistics))
curve_test3 <- ggplot(test3, aes(x = x, y = y))+
        geom_line(aes(colour = statistics))
curve_test3

[Figure: precision and recall curves for data pair #460]

Then I went on to modify the source code of curve_intersect(). The original source code is

    curve_intersect <- function(curve1, curve2, empirical=TRUE, domain=NULL) {
        if (!empirical & missing(domain)) {
                stop("'domain' must be provided with non-empirical curves")
        }
        
        if (!empirical & (length(domain) != 2 | !is.numeric(domain))) {
                stop("'domain' must be a two-value numeric vector, like c(0, 10)")
        }
        
        if (empirical) {
                # Approximate the functional form of both curves
                curve1_f <- approxfun(curve1$x, curve1$y, rule = 2)
                curve2_f <- approxfun(curve2$x, curve2$y, rule = 2)
                
                # Calculate the intersection of curve 1 and curve 2 along the x-axis
                point_x <- uniroot(function(x) curve1_f(x) - curve2_f(x),
                                   c(min(curve1$x), max(curve1$x)))$root
                
                # Find where point_x is in curve 2
                point_y <- curve2_f(point_x)
        } else {
                # Calculate the intersection of curve 1 and curve 2 along the x-axis
                # within the given domain
                point_x <- uniroot(function(x) curve1(x) - curve2(x), domain)$root
                
                # Find where point_x is in curve 2
                point_y <- curve2(point_x)
        }
        
        return(list(x = point_x, y = point_y))
}

I modified the uniroot() call in the third if statement. Instead of passing c(min(curve1$x), max(curve1$x)) as the interval to uniroot(), I used lower = -100000000, upper = 100000000. The modified function is

curve_intersect_tq <- function(curve1, curve2, empirical=TRUE, domain=NULL) {
        if (!empirical & missing(domain)) {
                stop("'domain' must be provided with non-empirical curves")
        }
        
        if (!empirical & (length(domain) != 2 | !is.numeric(domain))) {
                stop("'domain' must be a two-value numeric vector, like c(0, 10)")
        }
        
        if (empirical) {
                # Approximate the functional form of both curves
                curve1_f <- approxfun(curve1$x, curve1$y, rule = 2)
                curve2_f <- approxfun(curve2$x, curve2$y, rule = 2)
                
                # Calculate the intersection of curve 1 and curve 2 along the x-axis
                point_x <- uniroot(function(x) curve1_f(x) - curve2_f(x),
                                   lower = -100000000, upper = 100000000)$root
                
                # Find where point_x is in curve 2
                point_y <- curve2_f(point_x)
        } else {
                # Calculate the intersection of curve 1 and curve 2 along the x-axis
                # within the given domain
                point_x <- uniroot(function(x) curve1(x) - curve2(x), domain)$root
                
                # Find where point_x is in curve 2
                point_y <- curve2(point_x)
        }
        
        return(list(x = point_x, y = point_y))
}

I also tried changing the values of the lower = and upper = arguments. That did not work. I got the same error message, shown below.

curve_intersect_tq(recall_or1_4[[460]], precision_or1_4[[460]], empirical = TRUE, domain = NULL)

Error in uniroot(function(x) curve1_f(x) - curve2_f(x), c(min(curve1$x),  : 
  f() values at end points not of opposite sign

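I also tried using possibly(fun, NA) from the tidyverse, hoping the process would keep running even when an error occurs. It did not work when I used

(1) possibly(curve_intersect(), NA) or (2) possibly(uniroot(), NA)

The same error message appeared.

Why am I getting this error message? What are possible solutions? Thanks in advance.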

1 Answer

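Maybe a bit late to the party, but here is why your code still fails and what you can do about it, depending on what you want from your analysis:

First, the reason your code fails even after your adaptation is that you are only telling uniroot to search a wider range of x. The underlying curves, however, never intersect: there is simply no x for which curve1_f(x) - curve2_f(x) == 0.

From the documentation of uniroot:

"The function values at the end points must be of opposite signs (or zero), for extendInt="no", the default."

In the original curve_intersect implementation, uniroot searches the x interval defined by the data (i.e. c(min(curve1$x), max(curve1$x))). In your change, you tell it to search the x interval [-100000000, 100000000]. You could also set extendInt = "yes", but it would not change anything.
The problem is not the search interval but approxfun.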
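To see which pairs will trip this condition without crashing the whole run, the endpoint check can be done up front. The following is only a sketch: has_sign_change() is a hypothetical helper, and the two small lists are toy data standing in for recall_or1_4 and precision_or1_4. It mirrors what uniroot requires (a sign change between the interval endpoints), so it can also flag pairs that cross an even number of times inside the interval.

```r
# Hypothetical helper (not part of curve_intersect): TRUE if the interpolated
# difference curve1 - curve2 changes sign (or is zero) between the two ends of
# the x range of curve1, which is what uniroot() requires by default.
has_sign_change <- function(curve1, curve2) {
  f1 <- approxfun(curve1$x, curve1$y, rule = 2)
  f2 <- approxfun(curve2$x, curve2$y, rule = 2)
  ends <- range(curve1$x)
  d <- f1(ends) - f2(ends)
  d[1] * d[2] <= 0
}

# Toy stand-ins for the two lists (assumed data, for illustration only)
list1 <- list(data.frame(x = c(0, 1), y = c(0, 1)),      # crosses its partner
              data.frame(x = c(0, 1), y = c(0.9, 0.7)))  # stays above its partner
list2 <- list(data.frame(x = c(0, 1), y = c(1, 0)),
              data.frame(x = c(0, 1), y = c(0.5, 0.6)))

ok <- mapply(has_sign_change, list1, list2)
which(!ok)  # indices of the pairs uniroot() would reject
```

On your data, the same mapply() call over the two real lists should point at pair #460 (and any other pair with the same problem) in one pass.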

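approxfun only helps you by interpolating the empirical data between the points. Outside the data you passed in, the returned function does not know what to do.
approxfun lets you specify explicit values that y should take outside the empirically defined window (via its yleft / yright arguments), or lets you set a rule for each side.
In the code you posted above, rule = 2 means "use the value at the closest data extreme". So approxfun does not extrapolate the data you pass in. It only extends the known boundary values.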
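A small sketch of this behaviour, using made-up toy curves rather than your data (the data frames below are assumptions for illustration): with rule = 2 the interpolated functions stay flat outside the observed range, so the difference keeps its sign however wide the search interval is, and neither a huge interval nor extendInt = "yes" helps.

```r
# Toy data (assumed, not from the question): two curves that never cross on [0, 1]
curve1 <- data.frame(x = c(0, 0.5, 1), y = c(0.9, 0.8, 0.7))
curve2 <- data.frame(x = c(0, 0.5, 1), y = c(0.5, 0.55, 0.6))

curve1_f <- approxfun(curve1$x, curve1$y, rule = 2)
curve2_f <- approxfun(curve2$x, curve2$y, rule = 2)

# rule = 2 repeats the boundary value outside the data range
curve1_f(c(-5, 0, 1, 5))  # 0.9 0.9 0.7 0.7

# The difference is positive everywhere, even far outside [0, 1]
diff_f <- function(x) curve1_f(x) - curve2_f(x)
all(diff_f(c(-1e8, 0, 1, 1e8)) > 0)  # TRUE

# Both calls fail; the error is captured as a message instead of stopping the script
res_wide <- tryCatch(
  uniroot(diff_f, lower = -1e8, upper = 1e8)$root,
  error = function(e) conditionMessage(e)
)
res_extend <- tryCatch(
  uniroot(diff_f, interval = c(0, 1), extendInt = "yes")$root,
  error = function(e) conditionMessage(e)
)
```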

We can plot how curve1_f and curve2_f extend towards infinity outside the empirically defined x interval:

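# Assumes library(tidyverse) is loaded, curve1 / curve2 are the two data frames
# of pair #460, and curve1_f / curve2_f are their approxfun(..., rule = 2) fits
tibble(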
    x = seq(0, 1, by = 0.001),
    curve1_approxed = curve1_f(x),
    curve2_approxed = curve2_f(x)
  ) %>%
  pivot_longer(starts_with("curve"), names_to = "curve", values_to = "y") %>%
  ggplot(aes(x = x, y = y, color = curve)) +
  geom_line() +
  geom_vline(xintercept = c(min(curve1$x), max(curve1$x)), color = "grey75")

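[Figure: approxfun outside the empirically defined x interval]

So, now to what you can do to keep your code from crashing
(spoiler: it depends a lot on what you are trying to accomplish with your project):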

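  1. Accept that there is no intersection within the limits of the observed data.
    If you do not want to make any assumptions, I would suggest wrapping the mapped function in a tryCatch statement and letting it fail wherever the out-of-the-box solution gives you no result. Let's run this for the part of the list that previously crashed the whole thing: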
threshold_or1.fix1 <- map2_df(
  recall_or1_4, precision_or1_4,
  ~tryCatch({
    curve_intersect(.x, .y, empirical = TRUE, domain = NULL)
  }, error = function(e){
    return(tibble(.rows = 1))
  }),
  .id = "i"
)

Now there is just an NA row wherever curve_intersect cannot give you a result.

threshold_or1.fix1[459:461,]
# A tibble: 3 x 3
  i          x      y
  <chr>  <dbl>  <dbl>
1 459    0.116  0.809
2 460   NA     NA    
3 461    0.264  0.773
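  2. Try to extrapolate your data with a linear model.
    In this case, we will use a custom curve_intersect function. Let's wrap the problematic uniroot call in a tryCatch, and if no root can be found, we fit an lm to each curve and let uniroot find an intersection of the fitted lines.
    Depending on your experiment, this may or may not make sense, so I will let you be the judge here. Obviously, you could use other models instead of a simple lm if your data is more complex than this...
    Just to visualize this approach versus the default one: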
tibble(
    x = seq(-1, 2, by = 0.001),
    curve1_approxed = curve1_f(x),
    curve2_approxed = curve2_f(x),
    curve1_lm = predict(lm(y ~ x, data = curve1), newdata = tibble(x = x)),
    curve2_lm = predict(lm(y ~ x, data = curve2), newdata = tibble(x = x))
  ) %>%
  pivot_longer(starts_with("curve"), names_to = "curve", values_to = "y") %>%
  ggplot(aes(x = x, y = y, color = curve)) +
  geom_line() +
  geom_vline(xintercept = c(min(curve1$x), max(curve1$x)), color = "grey75")

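[Figure: approxfun vs. lm outside the empirically defined x interval]
You see, where approxfun "fails", with lm we assume that we can extrapolate linearly and find an intersection at x = 1.27, outside the range you observed.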
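For two straight lines the crossing point does not strictly need a root finder: it follows from the fitted coefficients. As a sketch with toy data (the data frames below are assumptions, not pair #460), setting a1 + s1 * x = a2 + s2 * x gives x = (a2 - a1) / (s1 - s2), which only exists if the slopes differ.

```r
# Toy data (assumed for illustration): two exactly linear curves on [0, 1]
curve1 <- data.frame(x = c(0, 0.5, 1), y = c(0.9, 0.8, 0.7))
curve2 <- data.frame(x = c(0, 0.5, 1), y = c(0.5, 0.55, 0.6))

# Intercept and slope of each fitted line
b1 <- coef(lm(y ~ x, data = curve1))
b2 <- coef(lm(y ~ x, data = curve2))

# Closed-form intersection of the two lines (undefined for equal slopes)
x_cross <- unname((b2[1] - b1[1]) / (b1[2] - b2[2]))
y_cross <- unname(b1[1] + b1[2] * x_cross)

c(x = x_cross, y = y_cross)  # x lies outside the observed range [0, 1]
```

Using uniroot on the fitted models, as in the function further down, has the advantage that it keeps working if you later swap lm for a non-linear model.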

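To take the second approach and include the lm extrapolation in our search, you could put together something like this
(here, too, only the third if was edited):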

curve_intersect_custom <- function(curve1, curve2, empirical=TRUE, domain=NULL) {
  if (!empirical & missing(domain)) {
    stop("'domain' must be provided with non-empirical curves")
  }
  
  if (!empirical & (length(domain) != 2 | !is.numeric(domain))) {
    stop("'domain' must be a two-value numeric vector, like c(0, 10)")
  }
  
  if (empirical) {
    
    return(
      tryCatch({
        # Approximate the functional form of both curves
        curve1_f <- approxfun(curve1$x, curve1$y, rule = 2)
        curve2_f <- approxfun(curve2$x, curve2$y, rule = 2)
        
        # Calculate the intersection of curve 1 and curve 2 along the x-axis
        point_x <- uniroot(
          f = function(x) curve1_f(x) - curve2_f(x),
          interval = c(min(curve1$x), max(curve1$x))
        )$root
        
        # Find where point_x is in curve 2
        point_y <- curve2_f(point_x)
        
        return(list(x = point_x, y = point_y, method = "approxfun"))
        
      }, error = function(e) {
        tryCatch({
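          # Fit each linear model once instead of refitting on every uniroot evaluation
          curve1_lm <- lm(y ~ x, data = curve1)
          curve2_lm <- lm(y ~ x, data = curve2)
          curve1_lm_f <- function(x) predict(curve1_lm, newdata = tibble(x = x))
          curve2_lm_f <- function(x) predict(curve2_lm, newdata = tibble(x = x))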
          
          point_x <- uniroot(
            f = function(x) curve1_lm_f(x) - curve2_lm_f(x),
            interval = c(min(curve1$x), max(curve1$x)),
            extendInt = "yes"
          )$root
          
          point_y <- curve2_lm_f(point_x)
          
          return(list(x = point_x, y = point_y, method = "lm"))
          
        }, error = function(e) {
          return(list(x = NA_real_, y = NA_real_, method = NA_character_))
        })
      })
    )
    
    
  } else {
    # Calculate the intersection of curve 1 and curve 2 along the x-axis
    # within the given domain
    point_x <- uniroot(function(x) curve1(x) - curve2(x), domain)$root
    
    # Find where point_x is in curve 2
    point_y <- curve2(point_x)
  }
  
  return(list(x = point_x, y = point_y))
}

For the problematic list element, it now tries to extrapolate with the naively fitted lm model:

threshold_or1.fix2 <- map2_df(
    recall_or1_4, precision_or1_4,
    ~curve_intersect_custom(.x, .y, empirical = TRUE, domain = NULL),
    .id = "i"
)

threshold_or1.fix2[459:461,]
# A tibble: 3 x 4
  i         x     y method   
  <chr> <dbl> <dbl> <chr>    
1 459   0.116 0.809 approxfun
2 460   1.27  0.813 lm       
3 461   0.264 0.773 approxfun
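On the possibly() attempt from the question: possibly() wraps a function object, so it has to be given the function itself rather than a call such as curve_intersect(), which is likely why that attempt did not help. A minimal sketch, assuming purrr is installed; find_root() is a hypothetical stand-in for curve_intersect and the two toy functions stand in for your curve pairs:

```r
library(purrr)

# Hypothetical stand-in for curve_intersect(): fails when there is no sign change
find_root <- function(f) uniroot(f, interval = c(0, 1))$root

# Pass the function itself (no parentheses); errors are replaced by NA
safe_find_root <- possibly(find_root, otherwise = NA_real_)

fs <- list(
  function(x) x - 0.5,  # root at 0.5 inside [0, 1]
  function(x) x + 1     # no root in [0, 1], uniroot() errors
)

roots <- map_dbl(fs, safe_find_root)
roots  # first entry close to 0.5, second entry NA
```

This gives the same "NA where it fails" behaviour as the tryCatch version in the first approach, just with less boilerplate.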

Hope this helps you understand and solve your problem :)

Answered 2020-09-19T20:40:40.597