我正在使用 albumentations 对 PyTorch 模型的输入图像应用数据变换，但出现了下面的错误，而我对它的原因毫无头绪。我只知道它是在应用变换时发生的，但不确定具体哪里出了问题。

ValueError: Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "<ipython-input-23-119ea6bc360e>", line 24, in __getitem__
    image = self.transform(image)
  File "/opt/conda/lib/python3.6/site-packages/albumentations/core/composition.py", line 164, in __call__
    need_to_run = force_apply or random.random() < self.p
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

以下是代码片段，即数据集（Dataset）的 `__getitem__()` 方法：

        # Fragment of the question's Dataset.__getitem__ (the enclosing def is not shown).
        image = cv2.imread(p_path)
        # OpenCV loads colour images in BGR order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # crop_image_from_gray and IMG_SIZE are defined elsewhere (not shown here).
        image = crop_image_from_gray(image)
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        # 4*image - 4*GaussianBlur(image, sigma=10) + 128: subtracts the blurred
        # image to enhance local contrast.
        image = cv2.addWeighted ( image,4, cv2.GaussianBlur( image , (0,0) , 10) ,-4 ,128)
        # BUG (cause of the ValueError in the traceback): albumentations transforms
        # must be called with keyword arguments and return a dict. Passed
        # positionally, the array is presumably bound to Compose.__call__'s
        # `force_apply` parameter, so `force_apply or random.random() < self.p`
        # evaluates the truth value of an array.
        # Fix: image = self.transform(image=image)['image']
        image = self.transform(image)


# Validation pipeline: normalise with the ImageNet mean/std.
# mean/std are arguments of Normalize, not of Compose, so they must be wrapped.
val_transform = albumentations.Compose([
    albumentations.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


valset = MyDataset(val_df, transform=val_transform)

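回答（共 1 个）：albumentations 的变换必须以关键字参数调用，例如 `augmented = self.transform(image=image)`，然后从返回的字典中取出 `augmented['image']`。如果按位置参数调用 `self.transform(image)`，图像数组会被当作 `Compose.__call__` 的第一个参数 `force_apply`，于是 `force_apply or random.random() < self.p` 需要判断整个数组的真值，从而抛出上面的 ValueError。完整示例如下：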



import cv2
import numpy as np
from albumentations import Compose, RandomCrop, Normalize, HorizontalFlip, Resize
from albumentations.pytorch import ToTensor
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class AlbumentationsDataset(Dataset):
    """Dataset that reads images with OpenCV and applies an albumentations pipeline.

    __init__ and __len__ functions are the same as in TorchvisionDataset.

    Args:
        file_paths: Sequence of image paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.
        transform: Optional albumentations transform (e.g. ``Compose``). It is
            called as ``transform(image=...)`` and its ``'image'`` entry is used.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load, preprocess and transform sample ``idx``.

        Returns:
            ``(image, label)``, where ``image`` is the (optionally transformed)
            array transposed from HWC to CHW order.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``file_paths[idx]``.
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # cv2.imread does not raise on a missing/unreadable file; it returns
        # None, which would otherwise fail obscurely inside cvtColor.
        image = cv2.imread(file_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_path!r}")

        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # crop_image_from_gray and IMG_SIZE are defined elsewhere (not shown here).
        image = crop_image_from_gray(image)
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        # 4*image - 4*GaussianBlur(image, sigma=10) + 128: local-contrast enhancement.
        image = cv2.addWeighted(image, 4, cv2.GaussianBlur(image, (0, 0), 10), -4, 128)

        # `image` is already a NumPy array, which is what albumentations consumes,
        # so no PIL conversion is needed.
        if self.transform is not None:
            # albumentations transforms must be called with keyword arguments and
            # return a dict; calling them positionally triggers the
            # "truth value of an array ... is ambiguous" ValueError.
            augmented = self.transform(image=image)
            image = augmented['image']

        # HWC -> CHW, the channel-first layout PyTorch models expect.
        image = np.transpose(image, (2, 0, 1))

        return image, label

# Normalise with the ImageNet mean/std. mean/std are arguments of Normalize,
# not of Compose, so they must be wrapped in a Normalize transform.
albumentations_transform = Compose([
    Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

albumentations_dataset = AlbumentationsDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=albumentations_transform,
)

test_loader = DataLoader(
    dataset=albumentations_dataset,
    batch_size=4,
    drop_last=False,
    shuffle=False,
)
回答于 2019-08-30 09:34:46