语境
使用带有tf.Estimator
接口的 TF 1.15 来训练和评估模型。尝试编写自定义 TF 指标,tf.keras.metric.Metric
用于此目的。
问题
我编写了一个自定义指标并将其包含在eval_metrics_ops
(下面的示例)中。如果我用指标定义一个估计器,我会收到以下错误。
ValueError: Please call update_state(...) on the "<metric_name>" metric
错误的措辞看起来很清楚(我必须调用update_state()
),但我不确定我在哪里调用update_state()
指标(不确定我是否应该调用)。不是一个最小的例子,但这是我写的一个演示指标。
class MyMetric(tf.keras.metrics.Metric):
def __init__(self, name="my_metric", **kwargs):
super(MyMetric, self).__init__(name=name, **kwargs)
def update_state(self, y_true, y_pred, sample_weight=None):
self.true_samples = tf.reduce_sum(y_true)
def result(self):
return self.true_samples
创建一个dict
指标名称是键,指标实例是值。这是它提到如何创建dict
for 的地方eval_metrics_ops
。
metrics_ops = {"my_metric": MyMetric()}`. # The TensorFlow 1.15 documentation does not say we have to call `update_state(....) anywhere.`
estimator_spec = tf.estimator.EstimatorSpec(mode, model.loss, eval_metric_ops=metrics_ops)
知道如何摆脱该错误吗?