2

我正在关注Kaggle 上的内核并遇到了这段代码。

#Validation function
n_folds = 5

def rmsle_cv(model):
    kf = KFold(n_folds, shuffle=True, random_state=42).get_n_splits(train.values)
    rmse= np.sqrt(-cross_val_score(model, train.values, y_train, scoring="neg_mean_squared_error", cv = kf))
    return(rmse)

我了解 KFold 的用途和用途以及“cross_val_score”中使用的事实。我不明白为什么要使用“get_n_split”?据我所知,它返回用于交叉验证的迭代次数,即在这种情况下返回值 5。当然对于这一行:

rmse= np.sqrt(-cross_val_score(model, train.values, y_train, scoring="neg_mean_squared_error", cv = kf))

简历 = 5?这对我来说没有任何意义。如果 get_n_splits 返回一个整数,为什么还要使用它?我认为KFold 返回一个类,而get_n_splits返回一个整数。

任何人都可以澄清我的理解吗?

4

1 回答 1

4

我认为 KFold 返回一个类,而get_n_splits返回一个整数。

当然,KFold是一个类,其中一个类方法是get_n_splits,它返回一个整数;您显示的kf变量

kf = KFold(n_folds, shuffle=True, random_state=42).get_n_splits(train.values)

不是KFold类对象,是KFold().get_n_splits() 方法的结果,确实是整数。事实上,如果您查看文档get_n_splits()甚至不需要任何参数(它们实际上被忽略,并且仅出于与其他类和方法的兼容性原因而存在)。

至于该get_n_splits方法的受质疑实用性,能够查询此类对象以获取其参数设置(相反)绝不是一个坏主意;想象这样一种情况,您有多个不同的KFold对象,并且您需要在程序流程中以编程方式获取它们各自的 CV 折叠数。

于 2020-04-25T18:19:19.180 回答