我定义了一个函数来打印 K 折叠分数并确定最佳 c,但是在下面的代码中的这一行中出现错误(我也在代码中突出显示了该行):
best_c = results_table.loc[results_table['平均召回分数'].idxmax()]['C_parameter']
# Ad hoc function to print k_fold scores
def printing_Kfold_scores(x_train_data, y_train_data):
fold = KFold(n_splits=5, shuffle=False)
# Different c parameters (hyperparameter for K-fold)
c_param_range = [0.01, 0.1, 1, 10, 100]
results_table = pd.DataFrame(index = range(len(c_param_range), 2), columns =
['C_parameter_r', 'Mean recall score'])
results_table['C_parameter'] = c_param_range
# the. k-fold will gove 2 lists: train and test indices
j=0
for c_param in c_param_range:
print('------------------------------------------')
print('C parameter: ', c_param)
print('------------------------------------------')
print('')
recall_accs=[]
for iteration, indices in enumerate(fold.split(x_train_data), start=1):
#Calling the logistic regression model
lr = LogisticRegression(C = c_param, penalty = 'l1', solver='liblinear')
lr.fit(x_train_data.iloc[indices[0],:],
y_train_data.iloc[indices[0],:].values.ravel())
y_pred_undersample = lr.predict(x_train_data.iloc[indices[1],:].values)
recall_acc = recall_score(y_train_data.iloc[indices[1], :].values,
y_pred_undersample)
recall_accs.append(recall_acc)
print('Iteration ', iteration, ': recall_score = ', recall_acc)
results_table.ix[j, 'Mean recall score'] = np.mean(recall_accs)
j+=1
print('')
print('Mean recall score: ', np.mean(recall_accs))
print('')
########################################################################################
best_c = results_table.loc[results_table['Mean recall score'].idxmax()['C_parameter']
########################################################################################
#This is where I get the error#
#Checking best c parameter
print('********************************************************************************')
print('Best c parameter : ', best_c)
print('********************************************************************************')
return best_c
我得到的错误是
TypeError: reduction operation 'argmax' not allowed for this dtype