I'm trying to create a masked version of SparseCategoricalAccuracy in TF 2.0 that can be passed via compile(metrics=[masked_accuracy_fn()]).
The function looks like this:
import tensorflow as tf


def get_masked_acc_metric_fn(ignore_label=-1):
    """Gets the masked accuracy function."""
    def masked_acc_fn(y_true, y_pred):
        """Masked accuracy."""
        y_true = tf.squeeze(y_true)
        # Create mask for time steps we don't care about
        mask = tf.not_equal(y_true, ignore_label)
        masked_acc = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_masked_accuracy', dtype=tf.float32)(y_true, y_pred, sample_weight=mask)
        return masked_acc
    return masked_acc_fn
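For comparison, the same masking can be expressed as a subclass of tf.keras.metrics.SparseCategoricalAccuracy that builds the mask inside update_state, so the metric object (and its variables) is constructed once, outside any traced function. This is a sketch assuming standard Keras metric subclassing; the class name MaskedSparseCategoricalAccuracy is hypothetical, and it would be passed as compile(metrics=[MaskedSparseCategoricalAccuracy()]).

```python
import tensorflow as tf


class MaskedSparseCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
    """Sparse categorical accuracy that skips positions labelled ignore_label.

    Hypothetical helper class for illustration; not part of the Keras API.
    """

    def __init__(self, ignore_label=-1, name="masked_accuracy", **kwargs):
        super().__init__(name=name, **kwargs)
        self.ignore_label = ignore_label

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.squeeze(y_true)
        # Weight 1.0 where the label counts, 0.0 where it equals ignore_label.
        mask = tf.cast(tf.not_equal(y_true, self.ignore_label), self.dtype)
        if sample_weight is not None:
            # Combine with any incoming weights, matched to the mask's shape.
            sw = tf.reshape(tf.cast(sample_weight, self.dtype), tf.shape(mask))
            mask = mask * sw
        return super().update_state(y_true, y_pred, sample_weight=mask)
```

Because the variables live on the metric instance, calling update_state repeatedly from a tf.function (as Keras does in graph mode) reuses them instead of creating new ones.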
This works in eager mode. However, when running in graph mode, I get the following error:
ValueError: tf.function-decorated function tried to create variables on non-first call
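This error comes from tf.function's rule that variables may only be created on the first trace. A Keras metric object owns variables, so constructing SparseCategoricalAccuracy inside masked_acc_fn most likely creates new variables each time Keras traces its training step. The minimal sketch below, which leaves Keras out entirely (both function and class names are hypothetical), reproduces the rule and shows the usual create-once pattern of storing the variable on an object and creating it only if it does not exist yet.

```python
import tensorflow as tf


@tf.function
def creates_variable_each_call(x):
    # The Python body builds a brand-new tf.Variable every time it runs,
    # and tf.function may run the body more than once while tracing.
    v = tf.Variable(1.0)
    return v + x


class CreatesVariableOnce:
    """Holds the variable on the object so it is created at most once."""

    def __init__(self):
        self.v = None

    @tf.function
    def __call__(self, x):
        if self.v is None:
            # Executed only during the first trace; later traces reuse self.v.
            self.v = tf.Variable(1.0)
        return self.v + x
```

Calling creates_variable_each_call raises a ValueError, while CreatesVariableOnce can be called repeatedly; a metric created once outside the function plays the same role as self.v here.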