I'm running this code and I get an error from the fit function:

import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
lda = LinearDiscriminantAnalysis(shrinkage='auto')
lda.fit(np.random.rand(3,2),np.random.randint((1,1,1)))

Here is the error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-34-ec552dd1faa1> in <module>
      1 lda = LinearDiscriminantAnalysis(shrinkage='auto')
----> 2 lda.fit(np.random.rand(3,2),np.random.randint((1,1,1)))
      3 LinearDiscriminantAnalysis()

~/anaconda3/lib/python3.8/site-packages/sklearn/discriminant_analysis.py in fit(self, X, y)
    581         if self.solver == "svd":
    582             if self.shrinkage is not None:
--> 583                 raise NotImplementedError("shrinkage not supported")
    584             if self.covariance_estimator is not None:
    585                 raise ValueError(

NotImplementedError: shrinkage not supported

How can I fix it? (I got the same error after upgrading scikit-learn, and also on Google Colab.)

1 Answer

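shrinkage is not supported by the svd solver, which is the default solver of LinearDiscriminantAnalysis. You can use this parameter with one of the other solvers, such as eigen or lsqr, for example: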

LinearDiscriminantAnalysis(solver='lsqr',shrinkage='auto').fit(X_train, y_train)
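A minimal, self-contained sketch of the same fix, assuming NumPy and scikit-learn are installed. X and y below are hypothetical toy data standing in for X_train and y_train; labels are built explicitly because np.random.randint((1,1,1)) in the question always returns [0, 0, 0], i.e. a single class. The sketch first shows that the default solver is 'svd' and rejects shrinkage, then fits with lsqr and eigen.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Hypothetical toy data: 20 samples, 2 features, two classes of 10 each.
# (np.random.randint((1,1,1)) from the question always returns [0, 0, 0],
#  a single class, so the labels are built explicitly here instead.)
rng = np.random.default_rng(0)
X = rng.random((20, 2))
y = np.repeat([0, 1], 10)

# Without an explicit solver, LDA uses 'svd', which rejects any non-None shrinkage.
default_lda = LinearDiscriminantAnalysis(shrinkage="auto")
print("default solver:", default_lda.solver)
svd_error = None
try:
    default_lda.fit(X, y)
except NotImplementedError as exc:
    svd_error = str(exc)
print("svd error:", svd_error)

# 'lsqr' and 'eigen' both accept shrinkage; 'auto' chooses the
# shrinkage intensity with the Ledoit-Wolf lemma.
scores = {}
for solver in ("lsqr", "eigen"):
    lda = LinearDiscriminantAnalysis(solver=solver, shrinkage="auto").fit(X, y)
    scores[solver] = lda.score(X, y)  # training accuracy, between 0 and 1
print(scores)
```

If you also need lda.transform(...) for dimensionality reduction, eigen is probably the better choice: as far as I know, scikit-learn's lsqr solver supports classification only and raises an error on transform.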
answered 2021-10-21T11:30:38.417