I need to cross-validate a Keras model using stratified k-fold (an imbalanced multi-class task). Is it possible to use x_train/y_train with ImageDataGenerator (flow_from_directory) in folds = list(StratifiedKFold(k, shuffle=True, random_state=1).split(x_train, y_train))? There is an example on Kaggle (https://www.kaggle.com/stefanie04736/simple-keras-model-with-k-fold-cross-validation?select=train.json.7z), but x_train, y_train = next(train_generator) does not map the data and labels correctly. Any help would be greatly appreciated!
train_generator = ImageDataGenerator(rescale=1./255).flow_from_directory(
    directory=train_path,
    target_size=input_img[:-1],
    color_mode="rgb",
    batch_size=BATCH_SIZE,
    classes=target_names,
    class_mode="input")
test_generator = ImageDataGenerator(rescale=1./255).flow_from_directory(
    directory=test_path,
    target_size=input_img[:-1],
    color_mode="rgb",
    batch_size=BATCH_SIZE,
    classes=target_names,
    class_mode="input",
    shuffle=False)
test_labels = test_generator.classes
# Instantiate to load data and generate k stratified folds
k = 5
def load_data_kfold(k):
    # For StratifiedKFold, labels (y_train) must be a 1-D array of labels (cannot be one-hot)
    x_train, y_train = next(train_generator)
    print(x_train)
    print(y_train)
    folds = list(StratifiedKFold(k, shuffle=True, random_state=1).split(x_train, y_train))
    return folds, x_train, y_train
folds, x_train, y_train = load_data_kfold(k)
Traceback (most recent call last):
  File "C:/Users/LaRoche Lab/PycharmProjects/pythonProject2/R.py", line 122, in <module>
    folds, x_train, y_train = load_data_kfold(k)
  File "C:/Users/LaRoche Lab/PycharmProjects/pythonProject2/R.py", line 118, in load_data_kfold
    folds = list(StratifiedKFold(k, shuffle=True, random_state=1).split(x_train, y_train))
  File "C:\Users\LaRoche Lab\Anaconda3\envs\tensorflow\lib\site-packages\sklearn\model_selection\_split.py", line 735, in split
    y = check_array(y, ensure_2d=False, dtype=None)
  File "C:\Users\LaRoche Lab\Anaconda3\envs\tensorflow\lib\site-packages\sklearn\utils\validation.py", line 73, in inner_f
    return f(**kwargs)
  File "C:\Users\LaRoche Lab\Anaconda3\envs\tensorflow\lib\site-packages\sklearn\utils\validation.py", line 642, in check_array
    % (array.ndim, estimator_name))
ValueError: Found array with dim 4. Estimator expected <= 2.
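For context on the error: with class_mode="input" the generator yields the image batch itself as the target, so y_train is a 4-D array of images rather than class labels, and next(train_generator) returns only a single batch of BATCH_SIZE images, not the whole dataset. One possible way around this, sketched below, is to stratify over sample indices using the 1-D integer labels instead of over loaded images. In the sketch, the filenames and labels arrays are made-up stand-ins for what train_generator.filenames and train_generator.classes would hold, and make_stratified_folds is a hypothetical helper name, not part of Keras or scikit-learn.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Made-up stand-ins (assumption): in the real script these would come from
# train_generator.filenames and train_generator.classes, which list every
# file in directory order together with its 1-D integer class label.
class_sizes = [50, 30, 20]  # an imbalanced 3-class toy dataset
filenames = np.array([f"class_{c}/img_{i}.png"
                      for c, n in enumerate(class_sizes) for i in range(n)])
labels = np.array([c for c, n in enumerate(class_sizes) for _ in range(n)])


def make_stratified_folds(labels, k=5, seed=1):
    """Return a list of (train_idx, val_idx) index arrays stratified by label."""
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
    # split() only needs X for its length, so a placeholder array is enough
    # and no images have to be loaded into memory.
    return list(skf.split(np.zeros(len(labels)), labels))


folds = make_stratified_folds(labels, k=5)

for fold, (train_idx, val_idx) in enumerate(folds):
    # Per-class counts in the validation part of this fold
    val_counts = np.bincount(labels[val_idx], minlength=len(class_sizes))
    print(f"fold {fold}: train={len(train_idx)} val={len(val_idx)} "
          f"val class counts={val_counts.tolist()}")
    # filenames[train_idx] and filenames[val_idx] are the files for this fold
```

Each fold's file lists and labels could then be placed in a pandas DataFrame and passed to ImageDataGenerator.flow_from_dataframe with class_mode="categorical" to build per-fold training and validation generators. That step is not shown here and would need adapting to the actual directory layout.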