我正在使用 fable 和 future 尝试并行预测。不幸的是,在 for 循环的每次迭代中,model() 步骤似乎都比上一次耗费更多的时间和内存。我想做的是每次向前推进一周,并在每一步(可能同时使用多个模型)预测未来若干周。
我传递给 model() 函数的数据量每一步只增加不到 1%,但计算所需的时间却呈指数级增长。下面是一个简化的示例;在我的实际场景中,我还会先对截至该时间点的数值做一些计算,再把结果传给模型,这会使每次调用 model() 时计算时间的增长更加明显。
我做了一些调查,时间的增加似乎来自 fabletools 中的这一行代码。
我还打开了 future 包的调试选项,
计算时间增加的相关代码就是这一段。
我认为,在循环的每次后续迭代中,传递给每个集群节点的数据都超出了实际需要。有没有办法避免这种情况,确保只有 cur_training_data 沿调用栈向下传递?
或者,也许我的整体策略本身就有问题。我了解到 tsibble_stretch 可能是一种办法,但我担心为每个时间步都复制一份训练数据会大幅增加内存占用,这也是我选择"循环加过滤"的原因。总的来说,有没有更好的方法来实现这一点?
非常感谢您的阅读。
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
library(dtplyr)
library(tidyr)
library(tsibbledata)
library(fable)
#> Loading required package: fabletools
library(tsibble)
library(logger)
library(future)
library(tidyquant)
#> Loading required package: lubridate
#>
#> Attaching package: 'lubridate'
#> The following object is masked from 'package:tsibble':
#>
#> interval
#> The following objects are masked from 'package:base':
#>
#> date, intersect, setdiff, union
#> Loading required package: PerformanceAnalytics
#> Loading required package: xts
#> Loading required package: zoo
#>
#> Attaching package: 'zoo'
#> The following object is masked from 'package:tsibble':
#>
#> index
#> The following objects are masked from 'package:base':
#>
#> as.Date, as.Date.numeric
#>
#> Attaching package: 'xts'
#> The following objects are masked from 'package:dplyr':
#>
#> first, last
#>
#> Attaching package: 'PerformanceAnalytics'
#> The following object is masked from 'package:graphics':
#>
#> legend
#> Loading required package: quantmod
#> Loading required package: TTR
#> Registered S3 method overwritten by 'quantmod':
#> method from
#> as.zoo.data.frame zoo
#> Version 0.4-0 included new data defaults. See ?getSymbols.
#> == Need to Learn tidyquant? =======================================================================================================================
#> Business Science offers a 1-hour course - Learning Lab #9: Performance Analysis & Portfolio Optimization with tidyquant!
#> </> Learn more at: https://university.business-science.io/p/learning-labs-pro </>
#>
#> Attaching package: 'tidyquant'
#> The following object is masked from 'package:fable':
#>
#> VAR
library(tictoc)
# Run configuration ----
# All settings live in one named list so they can be handed around together.
my_vars <- list(
  # Name of the value column that is modelled
  unit_column_name = "adjusted",
  # Fit the models in parallel?
  run_parallel = TRUE,
  # How many forecast cycles to step through
  num_cycles_to_forecast = 9,
  # Forecast horizon (in weeks) for every cycle
  weeks_to_predict = 13,
  # How many stock symbols to download
  num_stock_symbols = 200
)
# Look up the constituents of the "SP500" index and keep their ticker symbols
stocks <- tq_index("SP500")
#> Getting holdings for SP500
symbols <- pull(stocks, symbol)
# Download prices from 2015-01-01 onwards for the first
# `num_stock_symbols` tickers
stock_prices <- tq_get(
  symbols[1:my_vars$num_stock_symbols],
  get = "stock.prices",
  from = "2015-01-01"
)
#> Warning: Problem with `mutate()` input `data..`.
#> i x = 'BRK.B', get = 'stock.prices': Error in getSymbols.yahoo(Symbols = "BRK.B", env = <environment>, verbose = FALSE, : Unable to import "BRK.B".
#> BRK.B download failed after two attempts. Error message:
#> HTTP error 404.
#> Removing BRK.B.
#> i Input `data..` is `purrr::map(...)`.
#> Warning: x = 'BRK.B', get = 'stock.prices': Error in getSymbols.yahoo(Symbols = "BRK.B", env = <environment>, verbose = FALSE, : Unable to import "BRK.B".
#> BRK.B download failed after two attempts. Error message:
#> HTTP error 404.
#> Removing BRK.B.
# Convert the downloaded prices to a tsibble keyed by `symbol` and indexed
# by `date`
stock_prices_tsibble <- stock_prices %>%
  as_tsibble(key = c(symbol), index = date)
# Prepare the training data: only one observation per week is needed.
# Column names are lower-cased, `date` becomes `forecast_week`, and only the
# rows whose `lubridate::wday()` value is 6 are kept (Friday under
# lubridate's default week start).
weekly_stocks_tsibble <- stock_prices_tsibble %>%
  setNames(tolower(names(.))) %>%
  rename(forecast_week = date) %>%
  select(symbol, forecast_week, adjusted) %>%
  mutate(weekday = lubridate::wday(forecast_week)) %>%
  filter(weekday == 6) %>%
  select(-weekday) %>%
  as_tsibble(key = c(symbol), index = forecast_week) %>%
  # Missing weeks are inserted with the value column set to 0. The column
  # name is injected with `!!`; the embrace operator `{{ }}` is only
  # meaningful around a function argument, so it is not used here.
  fill_gaps(!!as.name(my_vars$unit_column_name) := 0)
# Get the cycles we want to forecast for: the most recent
# `num_cycles_to_forecast` distinct weeks found anywhere in the data.
# Taking the last rows of the whole keyed table would return the tail of
# whichever symbol happens to be stored last, which is wrong when that
# symbol's history ends early or is shorter than the requested number of
# cycles.
my_vars$cycles_to_forecast <- utils::tail(
  sort(unique(weekly_stocks_tsibble$forecast_week)),
  n = my_vars$num_cycles_to_forecast
)
# Fit and forecast a set of fable models over successive forecast cycles.
#
# For every cycle the training data is restricted to rows whose
# `forecast_week` is strictly before the cycle value, the model definitions
# given in `...` are fitted with `model()`, and `weeks_to_predict` steps
# ahead are forecast.
#
# Arguments:
#   cycles_to_forecast   Vector of cycle values, compared against the
#                        `forecast_week` column of `actuals_tsibble`.
#   actuals_column_name  Name of the value column. Not used inside the
#                        function; kept so existing callers keep working.
#   weeks_to_predict     Forecast horizon, passed to `forecast()` as `h`.
#   run_parallel         If TRUE, a `multisession` future plan is active
#                        while the function runs; the plan that was active
#                        before is restored on exit.
#   actuals_tsibble      Data containing `symbol` and `forecast_week`.
#   ...                  Model definitions, forwarded unchanged to `model()`.
#
# Returns a tibble with the columns symbol, forecast_cycle, forecast_week,
# forecasted_units, method and method_specifics (empty when no cycle
# produced predictions).
run_my_models <- function(cycles_to_forecast, actuals_column_name, weeks_to_predict, run_parallel, actuals_tsibble, ...) {
  if (isTRUE(run_parallel)) {
    # `multiprocess` is deprecated in future; `multisession` is its
    # replacement. Restore the caller's plan so the function does not leak
    # a global setting.
    previous_plan <- plan(multisession)
    on.exit(plan(previous_plan), add = TRUE)
  }
  # One slot per cycle; bound together once at the end instead of growing a
  # tibble on every iteration.
  log_info("Creating holder tibble")
  cycle_results <- vector("list", length(cycles_to_forecast))
  # Fit through cycles (seq_along() is safe for zero cycles)
  for (i in seq_along(cycles_to_forecast)) {
    cur_cycle <- cycles_to_forecast[i]
    log_info("Running for cycle {i}/{length(cycles_to_forecast)}: {cur_cycle}")
    # Prepare current cycle's training data
    cur_training_data <- actuals_tsibble %>%
      filter(forecast_week < cur_cycle)
    # Skip cycles without any training rows
    if (nrow(cur_training_data) <= 0) {
      log_warn("No rows in current cycle training data")
      next
    }
    log_info("Training data: {min(cur_training_data$forecast_week)} - {max(cur_training_data$forecast_week)}")
    tic()
    # Fit models
    cur_fit <- cur_training_data %>%
      model(...)
    log_info("Models fitted")
    toc()
    # Predict, using the horizon supplied by the caller rather than a global
    predictions <- cur_fit %>%
      forecast(
        h = weeks_to_predict,
        point_forecast = list(forecasted_units = mean)
      )
    log_info("Predictions generated")
    # Collect useful prediction information: one row per symbol and model,
    # with the fitted model kept in `method_specifics`
    cur_fit_formatted <- cur_fit %>%
      as_tibble() %>%
      pivot_longer(
        cols = -c(symbol),
        names_to = "method",
        values_to = "method_specifics"
      ) %>%
      lazy_dt()
    collected_predictions <- predictions %>%
      as_tibble() %>%
      lazy_dt() %>%
      rename(method = .model) %>%
      left_join(cur_fit_formatted, by = c("symbol", "method")) %>%
      mutate(forecast_cycle = cur_cycle) %>%
      select(symbol, forecast_cycle, forecast_week, forecasted_units, method, method_specifics)
    log_info("Predictions colected")
    # NOTE(review): `method_specifics` carries the fitted model objects, so
    # the accumulated results grow with every cycle. Whether this affects
    # what future ships to the workers is unverified - TODO confirm by
    # profiling.
    cycle_results[[i]] <- as_tibble(collected_predictions)
  }
  # Skipped cycles leave NULL slots, which bind_rows() ignores
  bind_rows(cycle_results)
}
# Run the cycle-by-cycle fit/forecast over the weekly data using the settings
# stored in `my_vars`. Any named argument that is not a formal parameter of
# run_my_models() travels through `...` to `model()`; its name (`arima`)
# is the name of the corresponding model column.
model_predictions <- run_my_models(cycles_to_forecast = my_vars$cycles_to_forecast,
actuals_column_name = my_vars$unit_column_name,
weeks_to_predict = my_vars$weeks_to_predict,
run_parallel = my_vars$run_parallel,
actuals_tsibble = weekly_stocks_tsibble,
# Model definitions
# `!!as.name(...)` injects the column named by `my_vars$unit_column_name`
# as the ARIMA response
arima = ARIMA(!!as.name(my_vars$unit_column_name)))
#> INFO [2020-09-03 08:25:35] Creating holder tibble
#> INFO [2020-09-03 08:25:35] Running for cycle 1/9: 2020-07-03
#> INFO [2020-09-03 08:25:35] Training data: 2015-01-02 - 2020-06-26
#> INFO [2020-09-03 08:26:08] Models fitted
#> 33.27 sec elapsed
#> INFO [2020-09-03 08:26:10] Predictions generated
#> INFO [2020-09-03 08:26:10] Predictions colected
#> INFO [2020-09-03 08:26:11] Running for cycle 2/9: 2020-07-10
#> INFO [2020-09-03 08:26:11] Training data: 2015-01-02 - 2020-07-03
#> INFO [2020-09-03 08:26:42] Models fitted
#> 30.15 sec elapsed
#> INFO [2020-09-03 08:26:44] Predictions generated
#> INFO [2020-09-03 08:26:44] Predictions colected
#> INFO [2020-09-03 08:26:44] Running for cycle 3/9: 2020-07-17
#> INFO [2020-09-03 08:26:44] Training data: 2015-01-02 - 2020-07-10
#> INFO [2020-09-03 08:27:35] Models fitted
#> 50.63 sec elapsed
#> INFO [2020-09-03 08:27:37] Predictions generated
#> INFO [2020-09-03 08:27:37] Predictions colected
#> INFO [2020-09-03 08:27:38] Running for cycle 4/9: 2020-07-24
#> INFO [2020-09-03 08:27:38] Training data: 2015-01-02 - 2020-07-17
#> INFO [2020-09-03 08:28:43] Models fitted
#> 64.41 sec elapsed
#> INFO [2020-09-03 08:28:45] Predictions generated
#> INFO [2020-09-03 08:28:45] Predictions colected
#> INFO [2020-09-03 08:28:45] Running for cycle 5/9: 2020-07-31
#> INFO [2020-09-03 08:28:45] Training data: 2015-01-02 - 2020-07-24
#> INFO [2020-09-03 08:30:06] Models fitted
#> 81.08 sec elapsed
#> INFO [2020-09-03 08:30:09] Predictions generated
#> INFO [2020-09-03 08:30:09] Predictions colected
#> INFO [2020-09-03 08:30:09] Running for cycle 6/9: 2020-08-07
#> INFO [2020-09-03 08:30:09] Training data: 2015-01-02 - 2020-07-31
#> INFO [2020-09-03 08:31:55] Models fitted
#> 105.32 sec elapsed
#> INFO [2020-09-03 08:31:57] Predictions generated
#> INFO [2020-09-03 08:31:57] Predictions colected
#> INFO [2020-09-03 08:31:57] Running for cycle 7/9: 2020-08-14
#> INFO [2020-09-03 08:31:57] Training data: 2015-01-02 - 2020-08-07
#> INFO [2020-09-03 08:34:00] Models fitted
#> 123.16 sec elapsed
#> INFO [2020-09-03 08:34:02] Predictions generated
#> INFO [2020-09-03 08:34:02] Predictions colected
#> INFO [2020-09-03 08:34:02] Running for cycle 8/9: 2020-08-21
#> INFO [2020-09-03 08:34:02] Training data: 2015-01-02 - 2020-08-14
#> INFO [2020-09-03 08:36:27] Models fitted
#> 144.39 sec elapsed
#> INFO [2020-09-03 08:36:29] Predictions generated
#> INFO [2020-09-03 08:36:29] Predictions colected
#> INFO [2020-09-03 08:36:29] Running for cycle 9/9: 2020-08-28
#> INFO [2020-09-03 08:36:29] Training data: 2015-01-02 - 2020-08-21
#> INFO [2020-09-03 08:39:06] Models fitted
#> 156.76 sec elapsed
#> INFO [2020-09-03 08:39:08] Predictions generated
#> INFO [2020-09-03 08:39:08] Predictions colected
sessionInfo()
R version 4.0.2 (2020-06-22)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows Server x64 (build 14393)
Matrix products: default
locale:
[1] LC_COLLATE=English_Ireland.1252 LC_CTYPE=English_Ireland.1252 LC_MONETARY=English_Ireland.1252 LC_NUMERIC=C
[5] LC_TIME=English_Ireland.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
loaded via a namespace (and not attached):
[1] ps_1.3.4 digest_0.6.25 crayon_1.3.4 R6_2.4.1 lifecycle_0.2.0 reprex_0.3.0 magrittr_1.5 evaluate_0.14
[9] pillar_1.4.4 rlang_0.4.7 rstudioapi_0.11 fs_1.5.0 callr_3.4.3 whisker_0.4 vctrs_0.3.3 ellipsis_0.3.1
[17] rmarkdown_2.3 tools_4.0.2 processx_3.4.3 xfun_0.16 compiler_4.0.2 pkgconfig_2.0.3 clipr_0.7.0 htmltools_0.5.0
[25] knitr_1.29 tibble_3.0.3