1

语境

使用带有tf.Estimator接口的 TF 1.15 来训练和评估模型。尝试编写自定义 TF 指标,tf.keras.metric.Metric用于此目的。

问题

我编写了一个自定义指标并将其包含在eval_metrics_ops(下面的示例)中。如果我用指标定义一个估计器,我会收到以下错误。

ValueError: Please call update_state(...) on the "<metric_name>" metric 

错误的措辞看起来很清楚(我必须调用update_state()),但我不确定我在哪里调用update_state()指标(不确定我是否应该调用)。不是一个最小的例子,但这是我写的一个演示指标。

class MyMetric(tf.keras.metrics.Metric):
    def __init__(self, name="my_metric", **kwargs):
        super(MyMetric, self).__init__(name=name, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.true_samples = tf.reduce_sum(y_true)
    
    def result(self):
        return self.true_samples

创建一个dict指标名称是键,指标实例是值。是它提到如何创建dictfor 的地方eval_metrics_ops

metrics_ops = {"my_metric": MyMetric()}`. # The TensorFlow 1.15 documentation does not say we have to call `update_state(....) anywhere.`
estimator_spec = tf.estimator.EstimatorSpec(mode, model.loss, eval_metric_ops=metrics_ops)

知道如何摆脱该错误吗?

4

0 回答 0