
I'm trying to learn PyTorch and work with a custom dataset. Code credit: https://github.com/vineeth2309/Custom-Dataset-and-Dataloader-in-Torch

However, when I run the code, I get a `KeyError`.

import glob
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    def __init__(self):
        self.imgs_path = "Dog_Cat_Dataset/"
        file_list = glob.glob(self.imgs_path + "*")
        print(file_list)
        self.data = []
        for class_path in file_list:
            class_name = class_path.split("/")[-1]
            for img_path in glob.glob(class_path + "/*.jpeg"):
                self.data.append([img_path, class_name])
        print(self.data)
        self.class_map = {"dogs" : 0, "cats": 1}
        self.img_dim = (416, 416)
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, class_name = self.data[idx]
        img = cv2.imread(img_path)
        img = cv2.resize(img, self.img_dim)
        class_id = self.class_map[class_name]
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.permute(2, 0, 1)
        class_id = torch.tensor([class_id])
        return img_tensor, class_id

if __name__ == "__main__":
    dataset = CustomDataset()
    print (dataset)
    data_loader = DataLoader(dataset, batch_size=4, shuffle=True)
    print (data_loader)
    for imgs, labels in data_loader:
        print("Batch of images has shape: ",imgs.shape)
        print("Batch of labels has shape: ", labels.shape)

Stack trace:

C:\Users\parag\anaconda3\envs\tf-gpu\python.exe C:/Users/parag/PycharmProjects/Custom-Dataset-and-Dataloader-in-Torch/Main.py
['Dog_Cat_Dataset\\cats', 'Dog_Cat_Dataset\\dogs']
[['Dog_Cat_Dataset\\cats\\1.jpeg', 'Dog_Cat_Dataset\\cats'], ['Dog_Cat_Dataset\\cats\\2.jpeg', 'Dog_Cat_Dataset\\cats'], ['Dog_Cat_Dataset\\cats\\3.jpeg', 'Dog_Cat_Dataset\\cats'], ['Dog_Cat_Dataset\\cats\\4.jpeg', 'Dog_Cat_Dataset\\cats'], ['Dog_Cat_Dataset\\cats\\5.jpeg', 'Dog_Cat_Dataset\\cats'], ['Dog_Cat_Dataset\\dogs\\1.jpeg', 'Dog_Cat_Dataset\\dogs'], ['Dog_Cat_Dataset\\dogs\\2.jpeg', 'Dog_Cat_Dataset\\dogs'], ['Dog_Cat_Dataset\\dogs\\3.jpeg', 'Dog_Cat_Dataset\\dogs'], ['Dog_Cat_Dataset\\dogs\\4.jpeg', 'Dog_Cat_Dataset\\dogs'], ['Dog_Cat_Dataset\\dogs\\5.jpeg', 'Dog_Cat_Dataset\\dogs']]
<__main__.CustomDataset object at 0x00000254C0568FD0>
<torch.utils.data.dataloader.DataLoader object at 0x00000254C250A7F0>
Traceback (most recent call last):
  File "C:\Users\parag\PycharmProjects\Custom-Dataset-and-Dataloader-in-Torch\Main.py", line 40, in <module>
    for imgs, labels in data_loader.dataset:
  File "C:\Users\parag\PycharmProjects\Custom-Dataset-and-Dataloader-in-Torch\Main.py", line 29, in __getitem__
    class_id = self.class_map[class_name]
KeyError: 'Dog_Cat_Dataset\\cats'

Process finished with exit code 1

[My folder structure][1]

I've tried, but I can't resolve the error. Can anyone help me?

[1]: https://i.stack.imgur.com/1giJw.png


1 Answer


You used forward slashes where you should have used backslashes:

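    for class_path in file_list:
        # On Windows, glob returns paths like 'Dog_Cat_Dataset\\cats', so split on the backslash
        class_name = class_path.split("\\")[-1]
        # "\\" is one escaped backslash; a bare "\*" relies on an invalid escape sequence
        for img_path in glob.glob(class_path + "\\*.jpeg"):
            self.data.append([img_path, class_name])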
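To confirm that the separator is the culprit, the class-folder path from the question's output can be split both ways. This small check uses the standard-library `ntpath` module, which applies Windows path rules on any host OS, so it reproduces the problem even on Linux:

```python
import ntpath

# The class-folder path as glob returned it on Windows (see the output in the question)
class_path = "Dog_Cat_Dataset\\cats"
class_map = {"dogs": 0, "cats": 1}

# "/" does not occur in the string, so split("/") returns the whole path
print(class_path.split("/")[-1])                 # Dog_Cat_Dataset\cats
print(class_map.get(class_path.split("/")[-1]))  # None, so class_map[...] raises KeyError

# Splitting on a single backslash isolates the folder name
print(class_path.split("\\")[-1])                # cats

# ntpath.basename applies Windows path rules regardless of the host OS
print(ntpath.basename(class_path))               # cats
print(class_map[ntpath.basename(class_path)])    # 1
```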

I'm guessing you're running on Windows, whereas the example was written on Linux.
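If the same script should run on both Windows and Linux, another option is to let `pathlib` handle the separators instead of splitting strings by hand. This is a sketch, not the original repository's code, and `collect_samples` is a hypothetical helper name; it returns the same `[img_path, class_name]` pairs that `__init__` stores in `self.data`:

```python
from pathlib import Path


def collect_samples(root="Dog_Cat_Dataset", pattern="*.jpeg"):
    """Hypothetical helper: list [img_path, class_name] pairs under root/<class>/.

    Path.name returns the last path component using the host OS's separator
    rules, so no manual split on "/" or "\\" is needed.
    """
    samples = []
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue  # skip stray files sitting next to the class folders
        for img_path in sorted(class_dir.glob(pattern)):
            # str() keeps the path usable with cv2.imread, which expects a string
            samples.append([str(img_path), class_dir.name])
    return samples
```

In `__init__`, `self.data = collect_samples(self.imgs_path)` could then replace the `glob`/`split` loop, so `self.class_map[class_name]` receives plain `"cats"` or `"dogs"` on either OS.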

answered 2022-02-03T00:57:36.593