0

谷歌搜索这会让你无处可去,所以我决定通过将其发布为可搜索的问题来帮助未来的我和其他人。


def __init__():
    ...
    self.val_acc = pl.metrics.Accuracy()

def validation_step(self, batch, batch_index):
    ...
    self.val_acc.update(log_probs, label_batch)

ValueError: preds and target must have same number of dimensions, or one additional dimension for preds

log_probs.shape == (16, 4)和为label_batch.shape == (16, 4)

有什么问题?

4

1 回答 1

0

pl.metrics.Accuracy()需要一批dtype=torch.long标签,而不是一次性编码标签。

所以应该喂

self.val_acc.update(log_probs, torch.argmax(label_batch.squeeze(), dim=1))


这与torch.nn.CrossEntropyLoss

于 2021-03-04T11:35:21.820 回答