6

我有一个包含以下列的数据文件

'customer', 'calibrat' - 校准样本 = 1;验证样本 = 0;'流失','churndep','收入','mou',

数据文件包含大约 40000 行,其中 20000 的 calibrat 值为 1。我想将此数据拆分为

X1 = data.loc[:, data.columns != 'churn']
y1 = data.loc[:, data.columns == 'churn']
from imblearn.over_sampling import SMOTE
os = SMOTE(random_state=0)
X1_train, X1_test, y1_train, y1_test = train_test_split(X1, y1, test_size=0.3, random_state=0)

我想要的是在我的 X1_train 中应该有校准数据 calibrat = 1 并且在 X1_test 中应该有所有数据来验证 calibrat = 0

4

1 回答 1

4

sklearn.model_selection除了 .还有其他几个选项train_test_split。其中之一旨在解决您的要求。在这种情况下,您可以使用GroupShuffleSplit,如文档中所述,它提供随机训练/测试索引来根据第三方提供的组拆分数据。这在您进行交叉验证时很有用,并且您希望多次拆分验证训练,确保集合按group字段拆分。对于这些情况,您也有GroupKFold非常有用的。

因此,调整您的示例,这就是您可以做的。

比如说你有:

from sklearn.model_selection import GroupShuffleSplit

cols = ['customer', 'calibrat', 'churn', 'churndep', 'revenue', 'mou',]
X = pd.DataFrame(np.random.rand(10, 6), columns=cols)
X['calibrat'] = np.random.choice([0,1], size=10)

print(X)

   customer  calibrat     churn  churndep   revenue       mou
0  0.523571         1  0.394896  0.933637  0.232630  0.103486
1  0.456720         1  0.850961  0.183556  0.885724  0.993898
2  0.411568         1  0.003360  0.774391  0.822560  0.840763
3  0.148390         0  0.115748  0.089891  0.842580  0.565432
4  0.505548         0  0.370198  0.566005  0.498009  0.601986
5  0.527433         0  0.550194  0.991227  0.516154  0.283175
6  0.983699         0  0.514049  0.958328  0.005034  0.050860
7  0.923172         0  0.531747  0.026763  0.450077  0.961465
8  0.344771         1  0.332537  0.046829  0.047598  0.324098
9  0.195655         0  0.903370  0.399686  0.170009  0.578925

y = X.pop('churn')

您现在可以实例化GroupShuffleSplit,并按照您train_test_split的方式执行 ,唯一的区别是指定一个group列,该列将用于拆分Xy因此根据组值拆分组:

gs = GroupShuffleSplit(n_splits=2, train_size=.7, random_state=42)

如前所述,当您想要分成多个组时,这更方便,通常用于交叉验证目的。如问题中所述,这只是如何进行两次拆分的示例:

train_ix, test_ix = next(gs.split(X, y, groups=X.calibrat))

X_train = X.loc[train_ix]
y_train = y.loc[train_ix]

X_test = X.loc[test_ix]
y_test = y.loc[test_ix]

给予:

print(X_train)

   customer  calibrat  churndep   revenue       mou
3  0.148390         0  0.089891  0.842580  0.565432
4  0.505548         0  0.566005  0.498009  0.601986
5  0.527433         0  0.991227  0.516154  0.283175
6  0.983699         0  0.958328  0.005034  0.050860
7  0.923172         0  0.026763  0.450077  0.961465
9  0.195655         0  0.399686  0.170009  0.578925

print(X_test)

   customer  calibrat  churndep   revenue       mou
0  0.523571         1  0.933637  0.232630  0.103486
1  0.456720         1  0.183556  0.885724  0.993898
2  0.411568         1  0.774391  0.822560  0.840763
8  0.344771         1  0.046829  0.047598  0.324098
于 2020-04-09T07:33:03.170 回答