我将以下模型与 statmodels 一起使用。
import statsmodels.api as sm
class MultiLinRegLearner_A:
"""This class defines a Multi Linear regression algorithm."""
def __init__(self, X_train, Y_train):
self.X_train = sm.add_constant(X_train)
self.Y_train = Y_train
def train(self):
self.model = sm.OLS(self.Y_train, self.X_train).fit()
def query(self, X_test):
Y = self.model.predict(X_test)
return Y
mgl = MultiLinRegLearner_A(X[:150], Y[:150])
mgl.train()
mgl.query(X[150:])
在这里,直到训练部分它工作得很好,但是当我调用 mgl.query(X[150:]) 时会发生错误。我收到以下错误:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-49-c0758051747f> in <module>
2 mgl = MultiLinRegLearner_A(X[:150], Y[:150])
3 mgl.train()
----> 4 mgl.query(X[150:])
<ipython-input-48-0c0142618863> in query(self, X_test)
13
14 def query(self, X_test):
---> 15 Y = self.model.predict(X_test)
16 return Y
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\base\model.py in predict(self, exog, transform, *args, **kwargs)
1098
1099 predict_results = self.model.predict(self.params, exog, *args,
-> 1100 **kwargs)
1101
1102 if exog_index is not None and not hasattr(predict_results,
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\regression\linear_model.py in predict(self, params, exog)
378 exog = self.exog
379
--> 380 return np.dot(exog, params)
381
382 def get_distribution(self, params, scale, exog=None, dist_class=None):
<__array_function__ internals> in dot(*args, **kwargs)
ValueError: shapes (74,3) and (4,) not aligned: 3 (dim 1) != 4 (dim 0)
这是 Y 的一部分:
array([[12.67960262],
[12.61143303],
[13.20305061],
[13.14705372],
[13.06427574],
这是一块X:
array([[ 5.57469465e-02, 4.58671212e-02, 8.05328572e-01],
[ 5.95213534e-02, 2.82266197e-02, 4.60267261e-01],
[ 1.08318052e-01, 5.42789552e-02, 6.83864663e-01],
[ 1.06784149e-01, 2.89634220e-02, 4.62705431e-01],
[ 6.70113471e-02, 9.51949656e-03, 2.23702558e-01],
但是当我在一个类中使用模型时,如下所示,它完美地工作。
import statsmodels.api as sm
X = sm.add_constant(X)
model = sm.OLS(Y[:150],X[:150]).fit()
predictions = model.predict(X[150:])
print_model = model.summary()
有谁知道我该如何解决?