当我运行 Wide ResNet 代码时，出现运行时错误：`RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]`。我尝试了几种网上的解决方案，但都没有解决问题，反而引发了其他错误。我不知道如何解决它。所有相关的运行时错误都以注释的形式标注在下面的代码中。
elif args.data == 'kmnist':
normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
if args.data_augmentation:
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
else:
transform_train = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
# If following portion, it would be another runtime error
# RuntimeError: Given groups=1, weight of size 16 3 3 3, expected
#input[128, 1, 28, 28] to have 3 channels, but got 1 channels instead
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
# If I tried following portion, it would be
# RuntimeError: output with shape [1, 28, 28] doesn't match the
# broadcast shape [3, 28, 28]
# transform_test = transforms.Compose([
# transforms.ToTensor(),
# normalize
# ])
# If I tried following portion of the code, I received
# AttributeError: Can't pickle local object 'get_data_loaders.<locals>.<lambda>'
# transform_test = transforms.Compose([
# transforms.ToTensor(),
# transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
# ])
kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
KMNISTRandomLabels(root='./kmnistdata', train=True, download=True,
transform=transform_train, num_classes=args.num_classes,
corrupt_prob=args.label_corrupt_prob),
batch_size=args.batch_size, shuffle=shuffle_train, **kwargs)
val_loader = torch.utils.data.DataLoader(
KMNISTRandomLabels(root='./kmnistdata', train=False,
transform=transform_test, num_classes=args.num_classes,
corrupt_prob=args.label_corrupt_prob),
batch_size=args.batch_size, shuffle=False, **kwargs)
return train_loader, val_loader
"""
Fashion-MNIST dataset, with support for random labels
"""
import numpy as np
import torch
import torchvision.datasets as datasets
class FashionMNISTRandomLabels(datasets.FashionMNIST):
    """Fashion-MNIST dataset, with support for randomly corrupted labels.

    Params
    ------
    corrupt_prob: float
      Default 0.0. The probability of a label being replaced with a
      label drawn uniformly at random from ``range(num_classes)`` (the
      random draw may coincide with the true label).
    num_classes: int
      Default 10. The number of classes in the dataset.
    **kwargs
      Forwarded unchanged to ``datasets.FashionMNIST``.
    """

    def __init__(self, corrupt_prob=0.0, num_classes=10, **kwargs):
        super(FashionMNISTRandomLabels, self).__init__(**kwargs)
        self.n_classes = num_classes
        if corrupt_prob > 0:
            self.corrupt_labels(corrupt_prob)

    def corrupt_labels(self, corrupt_prob):
        """Replace each label with a random class with probability ``corrupt_prob``.

        A private NumPy generator seeded with 12345 is used, so the same
        samples are corrupted on every construction. It produces the same
        draws as ``np.random.seed(12345)`` would, but leaves NumPy's global
        random state untouched.
        """
        labels = np.array(self._get_labels())
        rng = np.random.RandomState(12345)
        mask = rng.rand(len(labels)) <= corrupt_prob
        rnd_labels = rng.choice(self.n_classes, mask.sum())
        labels[mask] = rnd_labels
        self._set_labels(labels)

    def _get_labels(self):
        """Return the current labels from whichever attribute the parent uses."""
        # Newer torchvision releases keep labels in ``targets``; the old
        # ``train_labels``/``test_labels`` names remain only as read-only,
        # deprecated properties. Older releases have no ``targets`` at all.
        if hasattr(self, 'targets'):
            return self.targets
        return self.train_labels if self.train else self.test_labels

    def _set_labels(self, labels):
        """Store the integer NumPy array ``labels`` back on the dataset."""
        if hasattr(self, 'targets'):
            if isinstance(self.targets, torch.Tensor):
                # Preserve the parent's container type and dtype.
                self.targets = torch.as_tensor(labels, dtype=self.targets.dtype)
            else:
                self.targets = [int(x) for x in labels]
            return
        # Older torchvision: labels are plain attributes. Cast from
        # numpy.int64 to builtin int, otherwise pytorch will fail.
        int_labels = [int(x) for x in labels]
        if self.train:
            self.train_labels = int_labels
        else:
            self.test_labels = int_labels