0

下面的代码取自 huggingface 的教程

from datasets import load_metric

metric= load_metric("glue", "mrpc")
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

在循环内部for batch in eval_dataloader:,我如何知道该批次包含数据集中的哪些索引?

DataLoader 是使用较早创建的

eval_dataloader = DataLoader(
    tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)

请注意,它没有改组标志,因此可以使用批量大小手动计数,但是如何进行改组呢?创建数据集和数据加载器时是否可以将其作为批处理的字段?

4

0 回答 0