解决方法是先创建一个自定义指标：
import torch
from ignite.metrics import Metric
from sklearn.metrics import f1_score
class F1Score(Metric):
    """Micro-averaged F1 score accumulated over all batches since the last reset.

    Each ``update`` call receives an ``(y_pred, y)`` pair. The predicted label is
    the argmax of ``y_pred`` along dim 1. ``y`` is compared against it with
    ``sklearn.metrics.f1_score(average='micro')``. It is presumably a 1-D tensor
    of integer class labels; confirm against the data loader.

    Each batch's score is weighted by its number of samples before it is
    accumulated. ``compute`` therefore returns a per-sample average over the
    whole run, not an unweighted mean of batch scores. The unweighted mean is
    biased whenever batches differ in size, e.g. a short final batch. For
    single-label multiclass predictions, micro F1 on a batch equals
    (#correct / batch size). Weighting by batch size therefore reproduces the
    exact micro F1 over all samples.
    """

    def __init__(self, *args, **kwargs):
        # Accumulators are initialised before Metric.__init__ runs, as in the
        # original code.
        self.f1 = 0.0   # sum over batches of (batch micro F1 * batch size)
        self.count = 0  # total number of samples accumulated
        super().__init__(*args, **kwargs)

    def update(self, output):
        """Accumulate the micro F1 of one batch, weighted by its size.

        Args:
            output: ``(y_pred, y)`` pair. ``y_pred`` holds per-class scores and
                ``y`` holds the target labels.
        """
        y_pred, y = output[0].detach(), output[1].detach()
        batch_size = y.shape[0]
        if batch_size == 0:
            # Nothing to score; skipping avoids passing empty arrays to sklearn.
            return
        _, predicted = torch.max(y_pred, 1)
        batch_f1 = f1_score(y.cpu(), predicted.cpu(), average='micro')
        self.f1 += batch_f1 * batch_size
        self.count += batch_size

    def reset(self):
        """Clear the accumulated score and sample count."""
        self.f1 = 0.0
        self.count = 0
        super().reset()

    def compute(self):
        """Return the sample-weighted micro F1 over all accumulated batches.

        Raises:
            NotComputableError: if no samples were accumulated since the last
                reset. The original code raised ``ZeroDivisionError`` here.
        """
        if self.count == 0:
            # Local import keeps this fix self-contained within the class.
            from ignite.exceptions import NotComputableError
            raise NotComputableError(
                'F1Score must have at least one example before it can be computed.')
        return self.f1 / self.count
然后就可以将它与 `create_supervised_evaluator` 或 `create_supervised_trainer` 一起使用，例如：
import logging
import torch
from ignite.engine import Events
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy, Fbeta
from ignite.metrics.precision import Precision
from ignite.metrics.recall import Recall
from metrics.f1score import F1Score
def inference(
        cfg,
        model,
        val_loader
):
    """Run one evaluation pass of ``model`` over ``val_loader`` and log metrics.

    Args:
        cfg: Configuration object. Only ``cfg.MODEL.DEVICE`` is read here.
        model: Model passed to ``create_supervised_evaluator``.
        val_loader: Batches consumed by ``evaluator.run``.

    Returns:
        ``evaluator.state.metrics`` after the run. The function previously
        returned ``None``, so callers that ignore the result are unaffected.
    """
    device = cfg.MODEL.DEVICE

    logger = logging.getLogger("template_model.inference")
    logger.info("Start inferencing")

    # The same Precision/Recall objects are handed to Fbeta, so the reported F1
    # is derived from those metric instances.
    precision = Precision(average=False)
    recall = Recall(average=False)
    f1 = Fbeta(beta=1.0, average=False, precision=precision, recall=recall)

    metrics = {'accuracy': Accuracy(),
               'precision': precision,
               'recall': recall,
               'custom': F1Score(),
               'f1': f1}

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device)

    # adding handlers using `evaluator.on` decorator API
    @evaluator.on(Events.EPOCH_COMPLETED)
    def print_validation_results(engine):
        # Named `results` so it does not shadow the `metrics` definition dict.
        results = engine.state.metrics
        _avg_accuracy = results['accuracy']
        # With average=False these are presumably per-class tensors.
        # torch.mean reduces each one to a single unweighted mean for logging.
        _precision = torch.mean(results['precision'])
        _recall = torch.mean(results['recall'])
        _f1 = torch.mean(results['f1'])
        _custom = results['custom']
        logger.info(
            "Test Results - Epoch: {} Avg accuracy: {:.3f}, precision: {:.3f}, recall: {:.3f}, f1 score: {:.3f}, custom: {:.2f}".format(
                engine.state.epoch, _avg_accuracy, _precision, _recall, _f1, _custom))

    evaluator.run(val_loader)
    return evaluator.state.metrics
输出结果如下：
Test Results - Epoch: 1 Avg accuracy: 0.758, precision: 0.776, recall: 0.766, f1 score: 0.759, custom: 0.76