0

我正在考虑使用 R targets,但我很难让它接受多个文件输出。

例如,我希望能够获取一个数据集,创建一个训练/测试拆分并将每个数据集写入一个单独的文件。

MWE 将是

_targets.R

library(targets)
source("R/functions.R")

set.seed(124)

list(
  # created using write.csv(mtcars, "data/mtcars.csv")
  tar_target(raw_data, "data/mtcars.csv", format = "file"),
  tar_target(data, read.csv(raw_data),
  # this throws an error here:
  tar_target(train_test, split_dataset(data), format = "file"),
# this only shows how I would try to use the train/test datasets
  tar_target(model, train_model(train_test)),
  tar_target(eval, eval_model(model, train_test))
)

其中split_dataset()定义在R/functions.R

split_dataset <- function(data) {
    idx <- sample.int(nrow(data), 0.8 * nrow(data))
    train <- data[idx, ]
    test <- data[-idx, ]
    write.csv(train, "data/train.csv")
    write.csv(test, "data/test.csv")
    return(c("data/train.csv", "data/test.csv"))
  }

一种替代方法是使用列表list(train = train, test = test),但如果可能,我希望能够访问任一数据集并将数据集保存为单独的文件。

另一种替代方法是在目标列表中定义索引,拆分数据集并将每个数据集写入单独的目标中。如果可能的话,我想将这些步骤浓缩为一个(如上所示),以使目标文件更易于理解。

4

1 回答 1

1

我建议附加idx为一列data,然后稍后针对trainandtest目标对其进行过滤。此外,您以后不需要format = "file"能够访问数据集。你可以使用tar_read()or tar_load()。草图:

library(targets)
library(tibble)

dir.create("data")
write.csv(mtcars, "data/mtcars.csv")

tar_script({
  library(tibble)
  split_data <- function(data) {
    idx <- sample.int(n = nrow(data), size = 0.8 * nrow(data))
    data$is_training <- seq_len(nrow(data)) %in% idx
    as_tibble(data)
  }
  
  list(
    tar_target(raw_data, "data/mtcars.csv", format = "file"),
    tar_target(data, split_data(read.csv(raw_data)), format = "feather"),
    tar_target(train, data[data$is_training, ], format = "feather"),
    tar_target(test, data[!data$is_training, ], format = "feather")
  )
})

tar_visnetwork()


tar_make()
#> ● run target raw_data
#> ● run target data
#> ● run target test
#> ● run target train
#> ● end pipeline

tar_read(train)
#> # A tibble: 25 x 13
#>    X             mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <chr>       <dbl> <int> <dbl> <int> <dbl> <dbl> <dbl> <int> <int> <int> <int>
#>  1 Mazda RX4    21       6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2 Mazda RX4 …  21       6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3 Datsun 710   22.8     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4 Hornet 4 D…  21.4     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5 Hornet Spo…  18.7     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6 Valiant      18.1     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7 Duster 360   14.3     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8 Merc 240D    24.4     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9 Merc 230     22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10 Merc 280C    17.8     6  168.   123  3.92  3.44  18.9     1     0     4     4
#> # … with 15 more rows, and 1 more variable: is_training <lgl>

tar_read(test)
#> # A tibble: 7 x 13
#>   X              mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>   <chr>        <dbl> <int> <dbl> <int> <dbl> <dbl> <dbl> <int> <int> <int> <int>
#> 1 Merc 280      19.2     6 168.    123  3.92  3.44  18.3     1     0     4     4
#> 2 Merc 450SLC   15.2     8 276.    180  3.07  3.78  18       0     0     3     3
#> 3 Lincoln Con…  10.4     8 460     215  3     5.42  17.8     0     0     3     4
#> 4 Fiat 128      32.4     4  78.7    66  4.08  2.2   19.5     1     1     4     1
#> 5 AMC Javelin   15.2     8 304     150  3.15  3.44  17.3     0     0     3     2
#> 6 Fiat X1-9     27.3     4  79      66  4.08  1.94  18.9     1     1     4     1
#> 7 Lotus Europa  30.4     4  95.1   113  3.77  1.51  16.9     1     1     5     2
#> # … with 1 more variable: is_training <lgl>

reprex 包(v1.0.0)于 2021-03-30 创建

于 2021-03-30T16:15:06.923 回答