
I'm trying to create a masked SparseCategoricalAccuracy in TF 2.0 that can be passed via compile(metrics=[masked_accuracy_fn()]).

The function looks like this:

def get_masked_acc_metric_fn(ignore_label=-1):
    """Gets the masked accuracy function."""
    def masked_acc_fn(y_true, y_pred):
        """Masked accuracy."""
        y_true = tf.squeeze(y_true)
        # Create mask for time steps we don't care about
        mask = tf.not_equal(y_true, ignore_label)
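        # Note: this constructs a new SparseCategoricalAccuracy object (and its
        # tf.Variables) every time masked_acc_fn is called.
        masked_acc = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_masked_accuracy', dtype=tf.float32)(y_true, y_pred, sample_weight=mask)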
        return masked_acc

    return masked_acc_fn

This works in eager mode. However, when running in graph mode, I get the error:

ValueError: tf.function-decorated function tried to create variables on non-first call
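The error points at the line that builds a SparseCategoricalAccuracy inside masked_acc_fn: each call constructs a new metric object with new tf.Variables, and a tf.function only allows variable creation on its first trace. A minimal sketch of one alternative, assuming integer class ids in y_true and per-class scores in the last axis of y_pred, is a stateless function that creates no variables at all. This is not the approach from the answer below; it is a swapped-in technique shown for comparison.

```python
import tensorflow as tf


def get_masked_acc_metric_fn(ignore_label=-1):
    """Stateless masked accuracy: creates no tf.Variables, so it is safe inside tf.function."""

    def masked_acc_fn(y_true, y_pred):
        # Flatten labels to 1-D and predictions to [N, num_classes].
        num_classes = tf.shape(y_pred)[-1]
        y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int64)
        y_pred = tf.reshape(y_pred, [-1, num_classes])
        predicted = tf.argmax(y_pred, axis=-1)  # int64 class ids
        mask = tf.not_equal(y_true, ignore_label)
        correct = tf.logical_and(tf.equal(y_true, predicted), mask)
        # Fraction of non-ignored positions predicted correctly; 0 if everything is masked.
        return tf.math.divide_no_nan(
            tf.reduce_sum(tf.cast(correct, tf.float32)),
            tf.reduce_sum(tf.cast(mask, tf.float32)),
        )

    return masked_acc_fn
```

When a plain function like this is passed in compile(metrics=[...]), Keras averages its per-batch return values, so the reported number is a mean of batch-level masked accuracies rather than the exact masked accuracy over the whole epoch. A Metric subclass that accumulates weighted counts avoids that approximation.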

This seems to work as a temporary workaround:

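import tensorflow as tf

class MaskedSparseCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
    def __init__(self, ignore_label=-1, name="masked_sparse_categorical_accuracy", dtype=None):
        super().__init__(name=name, dtype=dtype)
        self.ignore_label = ignore_label

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Keras passes sample_weight as a keyword argument, so accept it; ignored positions get weight 0.
        mask = tf.cast(tf.not_equal(y_true, self.ignore_label), self.dtype)
        if sample_weight is not None:  # assumes sample_weight has the same shape as y_true
            mask = mask * tf.cast(sample_weight, self.dtype)
        return super().update_state(y_true, y_pred, sample_weight=mask)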
answered 2019-12-15T18:20:40.857
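Why the workaround masks correctly: SparseCategoricalAccuracy accumulates a weighted mean across update_state calls (until the metric is reset), so using the 0/1 mask as sample_weight removes ignored positions from both the numerator and the denominator. Here \hat{y}_i is the argmax of the i-th row of y_pred:

$$
\text{acc} = \frac{\sum_i w_i \,\mathbb{1}[\hat{y}_i = y_i]}{\sum_i w_i},
\qquad
w_i = \mathbb{1}[\,y_i \neq \texttt{ignore\_label}\,]
$$

For example, with labels [0, 1, -1, 2] and predicted classes [0, 2, 2, 2], the weights are [1, 1, 0, 1] and the correct indicators are [1, 0, 0, 1], so acc = (1 + 0 + 1) / 3 = 2/3.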