0

有人可以帮我理解这个功能的作用吗?

我了解行打印,但在那之后我有点迷路了。从train_data.

def stratifiedShuffleSplit_data(X, y):
    sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
    for train_index, test_index in sss.split(X, y):
        print("len(TRAIN):", len(train_index), "len(TEST):", len(test_index))
        print("TRAIN:", train_index, "TEST:", test_index)

        train_data = [df.loc[ind] for ind in train_index]
        test_data = [df.loc[ind] for ind in test_index]
        save_datarows(train_data, datafile+".train")
        save_datarows(test_data, datafile+".test")
4

1 回答 1

0

假设您使用的是 Panda 包,

 pd.DataFrame.loc 

是一种基于位置的索引器 - 这是一个过于简化的版本。我将发布一些资源,可以帮助您更好地理解它。

train_data = [df.loc[ind] for ind in train_index]

在这里,您基本上迭代列表 ind 并存储相应的值 train_data 对于 test_data 的情况类似

我假设 save_datarows 是一个自定义函数,用于将 train_data 存储到扩展名为 .train 的文件中

希望这可以帮助。

这是一个非常好的参考,可以进一步澄清:

在 python 中使用 .loc 进行选择

https://www.geeksforgeeks.org/python-pandas-dataframe-loc/

于 2019-12-10T17:36:50.773 回答