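I am trying to adapt the scikit-learn example code for the Local Outlier Factor (LOF) algorithm to detect outliers in a dataset with 29 features, but the code does not work and I get the following error: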
File "sklearn/neighbors/binary_tree.pxi", line 1294, in sklearn.neighbors.kd_tree.BinaryTree.query (sklearn/neighbors/kd_tree.c:11337)
ValueError: query data dimension must match training data dimension
Note that I applied the same approach with a One-Class SVM and it produced results.
#--- import required libraries ---#
import csv
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.neighbors import LocalOutlierFactor
#--- read csv file ---#
with open('test.csv', 'r') as f:
    reader = csv.reader(f)
    csv_values = list(reader)
#--- convert data type from string to float ---#
def read_lines():
    with open('test.csv', 'rU') as data:
        reader = csv.reader(data)
        for row in reader:
            yield [ float(i) for i in row ]
#--- values for meshgrid ---#
xx, yy= np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5,5,100))
#--- Classify observations into normal and outliers ---#
X = []; X_train = []; X_test = []; X_outliers = []
for i in range(len(csv_values)):
    if csv_values[i][-1] == '0':
        X.append(csv_values[i][:-1])
    else:
        X_outliers.append(csv_values[i][:-1])
#--- convert lists to arrays ---#
X=np.array(X)
X_outliers1= np.array(X_outliers)
#--- figure for all 29 plots ---#
fig=plt.figure(1)
for i in range(27):
    #--- select 2 columns each time ---#
    X=X[:,i:i+2]
    X_outliers= X_outliers1[:,i:i+2]
    #--- classification ---#
    clf = LocalOutlierFactor(n_neighbors=20)
    y_pred = clf.fit_predict(X)
    y_pred_outliers = y_pred[998:]
    Z = clf._decision_function(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    #--- plot settings ---#
    fig.add_subplot(9,3,i+1)
    fig.set_figheight(20)
    fig.set_figwidth(16)
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.title("Local Outlier Factor (LOF)")
    plt.contourf(xx, yy, Z, cmap=plt.cm.Blues_r)
    a = plt.scatter(X[:998, 0], X[:998, 1], c='white')
    b = plt.scatter(X[998:, 0], X[998:, 1], c='red')
    plt.axis('tight')
    plt.xlim((-5, 5))
    plt.ylim((-5, 5))
    plt.legend([a, b],
               ["normal observations",
                "abnormal observations"],
               loc="upper left")
plt.savefig('test.png')
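The csv file contains 1000 rows (data points) and 29 columns (features); the last column labels each point as normal or an outlier. The file has 998 normal data points and 2 outliers.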
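One possible explanation, sketched on synthetic data rather than the real test.csv: the loop assigns the slice back to `X`, so after the first pass `X` has only two columns, and on the second pass `X[:, 1:3]` yields a single column. LOF would then be fit on 1-D data while the meshgrid query is 2-D, which would match the error above. The sketch below assumes scikit-learn 0.20 or later, where `novelty=True` exposes a public `decision_function`. The random data and the helper name `lof_grid_scores` are made up for illustration.

```python
import numpy as np
from sklearn.neighbors import LocalOutlierFactor

# Assumption: synthetic stand-in for the normal rows of test.csv
# (998 rows, 28 feature columns, label column already removed).
rng = np.random.RandomState(0)
X_normal = rng.normal(0.0, 1.0, size=(998, 28))

# What the posted loop does: X is overwritten on every pass.
X = X_normal
X = X[:, 0:2]        # i = 0: X now has only 2 columns
shrunk = X[:, 1:3]   # i = 1: only 1 column is left to slice
print(shrunk.shape)  # (998, 1)

# Same 2-D grid as in the posted code.
xx, yy = np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100))
grid = np.c_[xx.ravel(), yy.ravel()]


def lof_grid_scores(data, i, grid_points, n_neighbors=20):
    """Hypothetical helper: fit LOF on columns i and i+1, score the grid."""
    pair = data[:, i:i + 2]  # new name, so `data` keeps all its columns
    clf = LocalOutlierFactor(n_neighbors=n_neighbors, novelty=True)
    clf.fit(pair)
    return clf.decision_function(grid_points)


# One score surface per consecutive column pair, ready for plt.contourf.
Z_all = [lof_grid_scores(X_normal, i, grid).reshape(xx.shape)
         for i in range(27)]
```

Slicing into a separate name on each pass keeps the full array intact, so every pair of columns stays two-dimensional and matches the grid.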