我正在尝试使用他们自己的教程从他们的页面微调 Pytorch 模型。我在来自 Kaggle 的数据集 StaVer 上进行了尝试:rtatman/stamp-verification-staver-dataset。
他们代码的唯一变化是数据集的路径。训练模型部分出现的错误:
# let's train it for 10 epochs
from torch.optim.lr_scheduler import StepLR
num_epochs = 20
for epoch in range(num_epochs):
# train for one epoch, printing every 10 iterations
train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=1)
# update the learning rate
lr_scheduler.step()
# evaluate on the test dataset
evaluate(model, data_loader_test, device=device)
.
只需运行以下代码即可复制相同的错误:
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
dataset = StaverDataset("", get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=utils.collate_fn
)
a = iter(data_loader)
for i in range(len(a)):
images, targets = next(a)
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images, targets)
奇怪的是它在 image_382 之前工作正常。如果我删除所有以前的图像并在最后大约运行模型。50张图片也可以。似乎问题与图像无关,而与图像数量有关。通过模型()运行的图像数量。
这是错误:
scans/scans/stampDS-00380.png
scans/scans/stampDS-00381.png
scans/scans/stampDS-00382.png
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-40-6ccb0398c1aa> in <module>()
4 images = list(image for image in images)
5 targets = [{k: v for k, v in t.items()} for t in targets]
----> 6 output = model(images, targets)
7 model.eval()
6 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor)
3710 return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
3711 if input.dim() == 4 and mode == "nearest":
-> 3712 return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
3713 if input.dim() == 5 and mode == "nearest":
3714 return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
RuntimeError: Input and output sizes should be greater than 0, but got input (H: 1605, W: 2) output (H: 799, W: 0)
有谁知道错误与什么有关?先感谢您。迈克尔