我遇到的问题是使用SARIMAX
模型和 2 个变量在 python 中预测时间序列的一种相当简单的方法:
- 内生的:感兴趣的。
- 外生的:假设对内生变量有一些影响。
该示例使用 BTC 和 ETH 的每日价值,其中 BTC 是内生的,ETH 是内生的。
import datetime
import numpy
import numpy as np
import matplotlib.pyplot as plt
import math
import pandas as pd
import pmdarima as pm
import statsmodels.api as sm
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from datetime import date
from math import sqrt
from dateutil.relativedelta import relativedelta
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from statsmodels.tsa.statespace.sarimax import SARIMAX
import itertools
from random import random
import yfinance as yf
plt.style.use('ggplot')
使用 yahoo.finance API 获取数据的方法非常简单yf
today = datetime.datetime.today()
ticker = input('Enter your ticker: ')
df1 = yf.download(ticker, period = 'max', interval = '1d')
df1.reset_index(inplace = True)
df1
这需要手动完成 - 手动插入硬币的名称(在组合硬币方面为用户提供更多自由)。
Enter your ticker: BTC-USD
[*********************100%***********************] 1 of 1 completed
Date Open High Low Close Adj Close Volume
0 2014-09-17 465.864014 468.174011 452.421997 457.334015 457.334015 21056800
1 2014-09-18 456.859985 456.859985 413.104004 424.440002 424.440002 34483200
2 2014-09-19 424.102997 427.834991 384.532013 394.795990 394.795990 37919700
3 2014-09-20 394.673004 423.295990 389.882996 408.903992 408.903992 36863600
4 2014-09-21 408.084991 412.425995 393.181000 398.821014 398.821014 26580100
... ... ... ... ... ... ... ...
2677 2022-01-15 43101.898438 43724.671875 42669.035156 43177.398438 43177.398438 18371348298
2678 2022-01-16 43172.039062 43436.808594 42691.023438 43113.878906 43113.878906 17902097845
2679 2022-01-17 43118.121094 43179.390625 41680.320312 42250.550781 42250.550781 21690904261
2680 2022-01-18 42250.074219 42534.402344 41392.214844 42375.632812 42375.632812 22417209227
2681 2022-01-19 42365.046875 42462.070312 41248.902344 42142.539062 42142.539062 24763551744
2682 rows × 7 columns
我们的外生数据df1
也是如此。然后以相同的方式获取内生数据。
today = datetime.datetime.today()
ticker = input('Enter your ticker: ')
df2 = yf.download(ticker, period = 'max', interval = '1d')
df2.reset_index(inplace = True)
df2
Enter your ticker: ETH-USD
[*********************100%***********************] 1 of 1 completed
Date Open High Low Close Adj Close Volume
0 2017-11-09 308.644989 329.451996 307.056000 320.884003 320.884003 893249984
1 2017-11-10 320.670990 324.717987 294.541992 299.252991 299.252991 885985984
2 2017-11-11 298.585999 319.453003 298.191986 314.681000 314.681000 842300992
3 2017-11-12 314.690002 319.153015 298.513000 307.907990 307.907990 1613479936
4 2017-11-13 307.024994 328.415009 307.024994 316.716003 316.716003 1041889984
... ... ... ... ... ... ... ...
1528 2022-01-15 3309.844238 3364.537842 3278.670898 3330.530762 3330.530762 9619999078
1529 2022-01-16 3330.387207 3376.401123 3291.563721 3350.921875 3350.921875 9505934874
1530 2022-01-17 3350.947266 3355.819336 3157.224121 3212.304932 3212.304932 12344309617
1531 2022-01-18 3212.287598 3236.016113 3096.123535 3164.025146 3164.025146 13024154091
1532 2022-01-19 3163.054932 3170.838135 3055.951416 3123.905762 3123.905762 14121734144
1533 rows × 7 columns
现在是两个数据集对齐的合并步骤。
df1['Date'] = pd.to_datetime(df1['Date'])
df2['Date'] = pd.to_datetime(df2['Date'])
data = df2.merge(df1, on = 'Date', how = 'left')
看起来像这样:
Date Open High Low Close_x Adj Close Volume Close_y
0 2017-11-09 308.644989 329.451996 307.056000 320.884003 320.884003 893249984 7143.580078
1 2017-11-10 320.670990 324.717987 294.541992 299.252991 299.252991 885985984 6618.140137
2 2017-11-11 298.585999 319.453003 298.191986 314.681000 314.681000 842300992 6357.600098
3 2017-11-12 314.690002 319.153015 298.513000 307.907990 307.907990 1613479936 5950.069824
4 2017-11-13 307.024994 328.415009 307.024994 316.716003 316.716003 1041889984 6559.490234
... ... ... ... ... ... ... ... ...
1528 2022-01-15 3309.844238 3364.537842 3278.670898 3330.530762 3330.530762 9619999078 43177.398438
1529 2022-01-16 3330.387207 3376.401123 3291.563721 3350.921875 3350.921875 9505934874 43113.878906
1530 2022-01-17 3350.947266 3355.819336 3157.224121 3212.304932 3212.304932 12344309617 42250.550781
1531 2022-01-18 3212.287598 3236.016113 3096.123535 3164.025146 3164.025146 13024154091 42375.632812
1532 2022-01-19 3163.054932 3170.838135 3055.951416 3123.905762 3123.905762 14121734144 42142.539062
1533 rows × 8 columns
我只想关注 BTC 和 ETH 的收盘价:
X = data[['Close_y', 'Date']]
y = data['Close_x']
X = pd.get_dummies(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.1, random_state = 42, shuffle = False)
# grid search
X_train = X_train.drop('Date', axis = 1)
X_test = X_test.drop('Date', axis = 1)
寻找最佳网格:
# Define the p, d and q parameters to take any value between 0 and 3 (exclusive)
p = d = q = range(0, 1)
# Generate all different combinations of p, q and q triplets
pdq = list(itertools.product(p, d, q))
# Generate all different combinations of seasonal p, q and q triplets
# put 12 in the 's' position as we have monthly data
pdqs = [(x[0], x[1], x[2], 12) for x in list(itertools.product(p, d, q))]
### Run Grid Search ###
def sarimax_gridsearch(pdq, pdqs, maxiter=5):
ans = []
for comb in pdq:
for combs in pdqs:
try:
mod = SARIMAX(y_train, exog=X_train, order=comb, seasonal_order=combs)
output = mod.fit(maxiter=maxiter)
ans.append([comb, combs, output.bic])
print('SARIMAX {} x {}12 : BIC Calculated ={}'.format(comb, combs, output.bic))
except:
continue
# Find the parameters with minimal BIC value
# Convert into dataframe
ans_df = pd.DataFrame(ans, columns=['pdq', 'pdqs', 'bic'])
# Sort and return top 5 combinations
ans_df = ans_df.sort_values(by=['bic'], ascending=True)
print(ans_df)
ans_df = ans_df.iloc[0]
return ans_df['pdq'], ans_df['pdqs']
o, s = sarimax_gridsearch(pdq, pdqs)
做出预测
# future predictions
# create Exogenous variables
df1 = df1.reset_index()
df1 = df1.set_index('Date')
df1 = df1.sort_index()
li = []
ys = ['Close']
for i in ys:
a = df1[i]
train_set, test_set = np.split(a, [int(.80 * len(a))])
model = pm.auto_arima(train_set, stepwise=True, error_action='ignore',seasonal=True, m=7)
b = model.get_params()
order = b.get('order')
s_order = b.get('seasonal_order')
model = sm.tsa.statespace.SARIMAX(a,
order=order,
seasonal_order=s_order
)
model_fit = model.fit()
start_index = data.index.max().date()+ relativedelta(days=1)
end_index = date(start_index.year, start_index.month , start_index.day+10)
forecast = model_fit.predict(start=start_index, end=end_index)
#start_index = data.shape[0]
#end_index = start_index + 12
#forecast = model_fit.predict(start=start_index, end=end_index)
li.append(forecast)
df = pd.DataFrame(li)
df = df.transpose()
df.columns = ys
df = df.reset_index()
exo = df[['Close', 'index']]
exo = exo.set_index('index')
但是当我尝试根据 做出未来预测时exo
,如下所示:
#fit the model
print(b, s)
model_best = SARIMAX(y,exog=X.drop(['Date'],1), order=o, seasonal_order=s)
model_fit = model_best.fit()
model_fit.summary()
model_fit.plot_diagnostics(figsize=(15,12))
start_index = data.shape[0]
end_index = start_index + 12
pred_uc = model_fit.forecast(steps=13, start_index = start_index, end_index = end_index, exog = exo)
future_df = pd.DataFrame({'pred' : pred_uc})
print('Forecast:')
print(future_df)
plt.rcParams["figure.figsize"] = (8, 5)
#data = data.set_index('time')
plt.plot(data['Close_x'],color = 'blue', label = 'Actual')
plt.plot(pred_uc, color = 'orange',label = 'Predicted')
plt.show()
我收到这个烦人的错误:
ValueError Traceback (most recent call last)
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\tsa\statespace\mlemodel.py in _validate_out_of_sample_exog(self, exog, out_of_sample)
1757 try:
-> 1758 exog = exog.reshape(required_exog_shape)
1759 except ValueError:
ValueError: cannot reshape array of size 11 into shape (13,1)
ValueError: Provided exogenous values are not of the appropriate shape. Required (13, 1), got (11, 1).
有人可以解释我错在哪里或者我在这个模块中错过了哪些步骤吗?