0

我正在尝试在不平衡的图像数据集(1:250 类图像,0:4000 种图像)上使用 Pytorch 训练/验证 CNN,而现在,我只在我的训练集上尝试了增强(感谢@jodag) . 但是,我的模型仍在学习使用更多图像来支持类。

我想找到弥补我不平衡数据集的方法。

我考虑过使用不平衡数据采样器 ( https://github.com/ufoym/imbalanced-dataset-sampler ) 使用过采样/欠采样,但我已经使用采样器为我的 5 倍验证选择索引。有没有办法可以使用下面的代码实现交叉验证并添加这个采样器?同样,有没有办法让一个标签比另一个更频繁地增加?根据这些问题,是否有其他更简单的方法可以解决我尚未研究的不平衡数据集?

这是我到目前为止的一个例子

total_set = datasets.ImageFolder(PATH)
KF_splits = KFold(n_splits= 5, shuffle = True, random_state = 42)

for train_idx, valid_idx in KF_splits.split(total_set):
    #sampler to get indices for cross validation
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    #Use a wrapper to apply augmentation only to training set
    #These are dataloaders that pull images from the same folder but sort into validation and training sets
    #Though transforms augment only the training set, it doesn't address
    #the underlying issue of a heavily unbalanced dataset

    train_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['train']),
        batch_size=32, sampler=ImbalancedDatasetSampler(total_set))
    valid_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['val']),
        batch_size=32)

    print("Fold:" + str(i))

    for epoch in range(epochs):
        #Train/validate model below

`

感谢您的时间和帮助!

4

0 回答 0