我正在解决一个水果数据集上的目标检测问题,数据集地址:https://yadi.sk/d/UPwQB7OZrB48qQ 。下面是我的数据集类的代码:
# Integer tag for each fruit class name, numbered from 1 in the listed order.
class2tag = {
    fruit: tag
    for tag, fruit in enumerate(("apple", "orange", "banana"), start=1)
}
class FruitDataset(Dataset):
    """Object-detection dataset that loads every image and annotation eagerly.

    Each ``<name>.xml`` file found directly inside ``data_dir`` is paired with
    ``<name>.jpg`` next to it.  The XML is parsed with ``xmltodict`` and is
    expected to contain ``annotation -> object -> bndbox`` entries with
    ``xmin``/``ymin``/``xmax``/``ymax`` and a ``name`` that is a key of
    ``class2tag``.

    Attributes:
        images: list of RGB images as returned by OpenCV.
        annotations: list of dicts with ``"boxes"`` (float32 tensor, one
            ``[xmin, ymin, xmax, ymax]`` row per object) and ``"labels"``
            (int64 tensor of class tags).
        transform: optional callable, see ``__getitem__``.
    """

    def __init__(self, data_dir, transform=None):
        """Read all annotations and images from ``data_dir`` into memory.

        Args:
            data_dir: directory holding the ``.xml`` / ``.jpg`` pairs.
            transform: optional callable invoked as
                ``transform(image=..., bboxes=..., labels=...)``.

        Raises:
            FileNotFoundError: if the image belonging to an annotation
                cannot be read.
            KeyError: if an object's class name is not in ``class2tag``.

        Warns:
            UserWarning: if no annotation file is found, because an empty
                dataset almost always means ``data_dir`` is wrong.
        """
        import warnings

        self.images = []
        self.annotations = []
        self.transform = transform

        # Match "*.xml" (with the dot) and sort so the sample order does not
        # depend on the file system's listing order.
        for annotation in sorted(glob.glob(os.path.join(data_dir, "*.xml"))):
            image_fname = os.path.splitext(annotation)[0] + ".jpg"
            image = cv2.imread(image_fname)
            if image is None:
                # cv2.imread signals a missing/unreadable file by returning
                # None; fail here with a message that names the file.
                raise FileNotFoundError(
                    f"Could not read image {image_fname!r} "
                    f"for annotation {annotation!r}"
                )

            with open(annotation) as f:
                annotation_dict = xmltodict.parse(f.read())

            # A single <object> is parsed as a dict, several as a list, and
            # an annotation without objects has no "object" key at all.
            objects = annotation_dict["annotation"].get("object") or []
            if not isinstance(objects, list):
                objects = [objects]

            bboxes = []
            labels = []
            for obj in objects:
                bndbox = obj["bndbox"]
                # float() accepts both "12" and "12.0" coordinate strings.
                bboxes.append(
                    [float(bndbox[k]) for k in ("xmin", "ymin", "xmax", "ymax")]
                )
                labels.append(class2tag[obj["name"]])

            # Append image and annotation together so the two lists can never
            # get out of step if one of the steps above raises.
            self.images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            self.annotations.append(
                {
                    # Explicit shape keeps an object-free image at (0, 4).
                    "boxes": torch.tensor(bboxes, dtype=torch.float32).reshape(
                        len(bboxes), 4
                    ),
                    "labels": torch.tensor(labels, dtype=torch.int64),
                }
            )

        if not self.images:
            warnings.warn(
                f"No '*.xml' annotation files found in "
                f"{os.path.abspath(data_dir)!r}; the dataset is empty.",
                stacklevel=2,
            )

    def __getitem__(self, i):
        """Return ``(image, target)`` for sample ``i``.

        Without a transform the stored image and annotation dict are returned
        as they are.  With a transform, it is called the way Albumentations
        pipelines are called (``image=``, ``bboxes=``, ``labels=``) and must
        return a mapping with the keys ``"image"``, ``"bboxes"`` and
        ``"labels"``.  If you use torchvision transforms instead, this method
        has to be adapted.
        """
        image, target = self.images[i], self.annotations[i]
        if not self.transform:
            return image, target

        # Hand plain Python lists to the transform rather than torch tensors.
        res = self.transform(
            image=image,
            bboxes=target["boxes"].tolist(),
            labels=target["labels"].tolist(),
        )

        boxes = torch.as_tensor(res["bboxes"], dtype=torch.float32)
        if boxes.numel() == 0:
            # The transform may drop every box; keep the (N, 4) layout.
            boxes = boxes.reshape(0, 4)
        return res["image"], {
            "boxes": boxes,
            "labels": torch.as_tensor(res["labels"], dtype=torch.int64),
        }

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.images)
我在 Google Colab 中做这个项目,所以我已经挂载了 Google Drive 并解压了压缩包。
# Google Drive helper from the google.colab package.
from google.colab import drive
# Mount the Drive at /content/drive.
drive.mount('/content/drive')
然后我使用 albumentations 做了一些数据增强:
# FruitDataset stores boxes as [xmin, ymin, xmax, ymax], which Albumentations
# calls 'pascal_voc'; 'coco' would mean [x_min, y_min, width, height].
# label_fields names the extra keyword argument ('labels') that carries the
# per-box class tags, so they are filtered together with their boxes.
train_transform = A.Compose([
    A.Flip(p=0.25),
    A.RGBShift(p=0.2),
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))
# No augmentation for validation; the empty pipeline keeps the call signature
# identical to the training one.
val_transform = A.Compose(
    [], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels'])
)
# NOTE(review): these relative paths are resolved against the current working
# directory. If the archives were unpacked somewhere else (for example under
# the mounted Drive), no *.xml files are found and len(dataset) is 0 -- verify
# that the directories exist and contain the annotation files.
train_dataset = FruitDataset("./train_zip/train", transform=train_transform)
val_dataset = FruitDataset("./test_zip/test", transform=val_transform)
但是,当我运行 len(train_dataset) 时,得到的值为 0。
我不明白为什么数据集的大小为 0,也找不到问题出在哪里。非常感谢任何建议。