I'm using the following model with statsmodels:

import statsmodels.api as sm

class MultiLinRegLearner_A:
    
    """This class defines a Multi Linear regression algorithm."""
    
    def __init__(self, X_train, Y_train):
        self.X_train = sm.add_constant(X_train)
        self.Y_train = Y_train
        
    def train(self):
        self.model = sm.OLS(self.Y_train, self.X_train).fit()  
        
    def query(self, X_test):
        Y = self.model.predict(X_test)
        return Y

mgl = MultiLinRegLearner_A(X[:150], Y[:150])
mgl.train()
mgl.query(X[150:])

Everything works up to the training step, but calling mgl.query(X[150:]) raises the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-49-c0758051747f> in <module>
      2 mgl = MultiLinRegLearner_A(X[:150], Y[:150])
      3 mgl.train()
----> 4 mgl.query(X[150:])

<ipython-input-48-0c0142618863> in query(self, X_test)
     13 
     14     def query(self, X_test):
---> 15         Y = self.model.predict(X_test)
     16         return Y

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\base\model.py in predict(self, exog, transform, *args, **kwargs)
   1098 
   1099         predict_results = self.model.predict(self.params, exog, *args,
-> 1100                                              **kwargs)
   1101 
   1102         if exog_index is not None and not hasattr(predict_results,

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\regression\linear_model.py in predict(self, params, exog)
    378             exog = self.exog
    379 
--> 380         return np.dot(exog, params)
    381 
    382     def get_distribution(self, params, scale, exog=None, dist_class=None):

<__array_function__ internals> in dot(*args, **kwargs)

ValueError: shapes (74,3) and (4,) not aligned: 3 (dim 1) != 4 (dim 0)
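The last line of the traceback shows where it fails: statsmodels computes predictions as np.dot(exog, params). Because sm.add_constant is applied to X_train in __init__, the fitted model has 4 parameters (the intercept plus one per column of X), while X[150:] reaches predict unchanged with only 3 columns. The following minimal NumPy sketch reproduces the mismatch using zero-filled placeholder arrays of the same shapes (the arrays are illustrative, not the real data):

```python
import numpy as np

# Placeholder arrays with the shapes from the error message (values are irrelevant):
# 74 test rows with 3 feature columns, and 4 fitted parameters (const + 3 slopes).
X_test = np.zeros((74, 3))
params = np.zeros(4)

raised = False
try:
    np.dot(X_test, params)  # the product OLS.predict computes (see the traceback)
except ValueError as err:
    raised = True
    print(err)

# Prepending a column of ones (what sm.add_constant does by default) makes the
# inner dimensions agree: (74, 4) dot (4,) -> (74,)
X_test_const = np.column_stack([np.ones(len(X_test)), X_test])
print(np.dot(X_test_const, params).shape)  # (74,)
```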

Here is part of Y:

array([[12.67960262],
       [12.61143303],
       [13.20305061],
       [13.14705372],
       [13.06427574],

Here is part of X:

array([[ 5.57469465e-02,  4.58671212e-02,  8.05328572e-01],
       [ 5.95213534e-02,  2.82266197e-02,  4.60267261e-01],
       [ 1.08318052e-01,  5.42789552e-02,  6.83864663e-01],
       [ 1.06784149e-01,  2.89634220e-02,  4.62705431e-01],
       [ 6.70113471e-02,  9.51949656e-03,  2.23702558e-01],

However, when I use the model outside of a class, as shown below, it works perfectly.

import statsmodels.api as sm

X = sm.add_constant(X)
model = sm.OLS(Y[:150],X[:150]).fit()
predictions = model.predict(X[150:])
print_model = model.summary()

Does anyone know how I can fix this?
