1

试图将显着性图放到图像上并制作一个新的数据集

trainloader = utilsxai.load_data_cifar10(batch_size=1,test=False)
testloader =  utilsxai.load_data_cifar10(batch_size=128, test=True)

这个 load_cifar10 是 torchvision

data = trainloader.dataset.data 

trainloader.dataset.data = (data * sal_maps_hf).reshape(data.shape)

带有 (50000,32,32,3) 的 sal_maps_hf 形状和带有 (50000,32,32,3)的 trainloader
形状

但是当我运行这个

for idx,img in enumerate(trainloader):

-------------------------------------------------- ------------------------- KeyError Traceback(最近一次调用最后)~/venv/lib/python3.7/site-packages/PIL/Image .py in fromarray(obj, mode) 2644 typekey = (1, 1) + shape[2:], arr["typestr"] -> 2645 mode, rawmode = _fromarray_typemap[typekey] 2646 除了 KeyError:

KeyError: ((1, 1, 3), '

在处理上述异常的过程中,又出现了一个异常:

----> 1 show_images(trainloader) 中的 TypeError Traceback (最近一次调用最后一次)

在 show_images(trainloader) 1 def show_images(trainloader): ----> 2 for idx,(img,target) in enumerate(trainloader): 3 img = img.squeeze() 4 #pritn(img) 5 img = torch .张量(img)

~/venv/lib/python3.7/site-packages/torch/utils/data/dataloader.py in next (self) 344 def next (self): 345 index = self._next_index() # may raise StopIteration --> 346 data = self._dataset_fetcher.fetch(index) # 可能会引发 StopIteration 347 if self._pin_memory: 348 data = _utils.pin_memory.pin_memory(data)

~/venv/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py​​ in fetch(self, possible_batched_index) 42 def fetch(self, possible_batched_index): 43 if self.auto_collat​​ion: --- > 44 data = [self.dataset[idx] for idx in possible_batched_index] 45 else: 46 data = self.dataset[possibly_batched_index]

~/venv/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py​​ in (.0) 42 def fetch(self, possible_batched_index): 43 if self.auto_collat​​ion: ---> 44 data = [self.dataset[idx] for idx in possible_batched_index] 45 else: 46 data = self.dataset[possibly_batched_index]

~/venv/lib/python3.7/site-packages/torchvision/datasets/cifar.py in getitem (self, index) 120 # 这样做是为了与所有其他数据集一致 121 # 返回 PIL 图像 -- > 122 img = Image.fromarray(img) 123 124 如果 self.transform 不是 None:

~/venv/lib/python3.7/site-packages/PIL/Image.py in fromarray(obj, mode) 2645 模式,rawmode = _fromarray_typemap[typekey] 2646 除了 KeyError: -> 2647 raise TypeError("Cannot handle this data类型") 2648 其他:2649 原始模式 = 模式

TypeError:无法处理此数据类型

trainloader.dataset.__getitem__

getitem of Dataset CIFAR10 数据点数:50000 根位置:/mnt/3CE35B99003D727B/input/pytorch/data 拆分:训练 StandardTransform 变换:Compose( Resize(size=32, interpolation=PIL.Image.BILINEAR) ToTensor() )

4

1 回答 1

1

sal_maps_hf的不是np.uint8

根据问题和评论中的部分信息,我猜您的掩码是dtype np.float(或类似的),并且通过将data * sal_maps_hf您的数据相乘,将其转换为dtype其他类型,np.uint8而不是稍后PIL.Image抛出异常。

尝试:

trainloader.dataset.data = (data * sal_maps_hf).reshape(data.shape).astype(np.uint8)
于 2019-12-30T12:57:18.010 回答