我正在尝试在不平衡的图像数据集(1:250 类图像,0:4000 种图像)上使用 Pytorch 训练/验证 CNN,而现在,我只在我的训练集上尝试了增强(感谢@jodag) . 但是,我的模型仍在学习使用更多图像来支持类。
我想找到弥补我不平衡数据集的方法。
我考虑过使用不平衡数据采样器 ( https://github.com/ufoym/imbalanced-dataset-sampler ) 使用过采样/欠采样,但我已经使用采样器为我的 5 倍验证选择索引。有没有办法可以使用下面的代码实现交叉验证并添加这个采样器?同样,有没有办法让一个标签比另一个更频繁地增加?根据这些问题,是否有其他更简单的方法可以解决我尚未研究的不平衡数据集?
这是我到目前为止的一个例子
total_set = datasets.ImageFolder(PATH)
KF_splits = KFold(n_splits= 5, shuffle = True, random_state = 42)
for train_idx, valid_idx in KF_splits.split(total_set):
#sampler to get indices for cross validation
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
#Use a wrapper to apply augmentation only to training set
#These are dataloaders that pull images from the same folder but sort into validation and training sets
#Though transforms augment only the training set, it doesn't address
#the underlying issue of a heavily unbalanced dataset
train_loader = torch.utils.data.DataLoader(
WrapperDataset(total_set, transform=data_transforms['train']),
batch_size=32, sampler=ImbalancedDatasetSampler(total_set))
valid_loader = torch.utils.data.DataLoader(
WrapperDataset(total_set, transform=data_transforms['val']),
batch_size=32)
print("Fold:" + str(i))
for epoch in range(epochs):
#Train/validate model below
`
感谢您的时间和帮助!