MahalanobisDistance
期望一个参数V
,它是协方差矩阵,并且可选地另一个参数VI
是协方差矩阵的逆矩阵。此外,这两个参数都是命名的,而不是位置的。
还要检查sklearn repoMahalanobisDistance
文件scikit-learn/sklearn/neighbors/dist_metrics.pyx
中类的文档字符串。
例子:
In [18]: import numpy as np
In [19]: from sklearn.datasets import make_classification
In [20]: from sklearn.neighbors import DistanceMetric
In [21]: X, y = make_classification()
In [22]: DistanceMetric.get_metric('mahalanobis', V=np.cov(X))
Out[22]: <sklearn.neighbors.dist_metrics.MahalanobisDistance at 0x107aefa58>
编辑:
由于某些原因(错误?),您不能将距离对象传递给NearestNeighbor
构造函数,而是需要使用距离度量的名称。此外,设置algorithm='auto'
(默认为'ball_tree'
)似乎不起作用;所以X
从上面的代码给出你可以这样做:
In [23]: nn = NearestNeighbors(algorithm='brute',
metric='mahalanobis',
metric_params={'V': np.cov(X)})
# returns the 5 nearest neighbors of that sample
In [24]: nn.fit(X).kneighbors(X[0, :])
Out[24]: (array([[ 0., 3.21120892, 3.81840748, 4.18195987, 4.21977517]]),
array([[ 0, 36, 46, 5, 17]]))