我正在使用已经在另一个上下文中运行的多标签分类器。
数据分类器包含来自图表的节点访问的数据,可在此处获得: https ://drive.google.com/file/d/1xD2dq4UL0UqQsuvEWgjSnFLBpjC5xlVL/view?usp=sharing
使用以下命令保存 csv:
data_classifier2.to_csv('data_classifier2.csv', encoding='utf-8', sep=',')
代码:
size_test = 0.2
splits = 10
train = []
test = []
X = data_classifier2
rs = ShuffleSplit(n_splits=splits, test_size= size_test, random_state=52)
rs.get_n_splits(X)
for train_index, test_index in rs.split(X):
#print("%s %s" % (train_index, test_index))
train.append(train_index)
test.append(test_index)
for i in range(0, len(train)):
size_features = len(X.columns)
X_train = data_classifier2.iloc[train[i],65:size_features]
y_train = data_classifier2.iloc[train[i],6:64]
X_test = data_classifier2.iloc[test[i],65:size_features]
y_test = data_classifier2.iloc[test[i],6:64]
categories = y_test.columns.values.tolist()
ids = y_test.index
classifier_setup = build_model(test_type)
clf = classifier_setup
clf.fit(X_train,y_train)
predict = clf.predict(X_test).toarray()
probability = clf.predict_proba(X_test).toarray()
predictions = pd.DataFrame(predict, index=ids, columns=categories) # with header
probabilities = pd.DataFrame(probability, index=ids, columns=categories) # with header
X_train、X_test、y_train 和 y_test 的形状似乎还可以。我想这可能是任何标签概率都缺乏价值,但我不确定!无论如何,如果这是真的,我们怎么能避免呢?
错误信息:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-74-17328f2b4b23> in <module>
----> 1 probability = clf.predict_proba(X_test).toarray()
/opt/anaconda3/lib/python3.7/site-packages/skmultilearn/problem_transform/br.py in predict_proba(self, X)
211 classifier.predict_proba(
212 self._ensure_input_format(X))
--> 213 )[:, 1] # probability that label is assigned
214
215 return result
/opt/anaconda3/lib/python3.7/site-packages/scipy/sparse/_index.py in __getitem__(self, key)
33 """
34 def __getitem__(self, key):
---> 35 row, col = self._validate_indices(key)
36 # Dispatch to specialized methods.
37 if isinstance(row, INT_TYPES):
/opt/anaconda3/lib/python3.7/site-packages/scipy/sparse/_index.py in _validate_indices(self, key)
142 col = int(col)
143 if col < -N or col >= N:
--> 144 raise IndexError('column index (%d) out of range' % col)
145 if col < 0:
146 col += N
IndexError: column index (1) out of range