我目前有一个弱监督项目,我需要在数据集前面放置一个“掩码”。我现在的问题是我不知道该怎么做。让我用一些代码和图像进一步解释。
我正在使用必须以这种方式编辑的 MNIST 数据集。如您所见,中间的正方形被切掉了。下面的代码用于使用 for 循环编辑 MNIST。
for i in range(int(image_size/2-5),int(image_size/2+3)):
for j in range(int(image_size/2-5),int(image_size/2+3)):
image[i][j] = 0
但是,我目前不确定如何在数据加载器转换中使用它。数据加载器和转换的代码如下所示:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(
root="~/torch_datasets", train=True, transform=transform, download=True
)
test_dataset = torchvision.datasets.MNIST(
root="~/torch_datasets", train=False, transform=transform, download=True
)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=32, shuffle=False, num_workers=4
)
def imshow(img):
#img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(train_loader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
那么有没有一种直接的方法可以将转换应用于 中的完整数据集torchvision.transforms.Compose
?