2

我创建了一个自定义 Keras 指标,类似于下面的演示实现:

import tensorflow as tf

class MyMetric(tf.keras.metrics.Mean):

    def __init__(self, name='my_metric', dtype=None):
        super(MyMetric, self).__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super(MyMetric, self).update_state(
            y_pred, sample_weight=sample_weight)

我已将实现转换为带有 init/main 文件的 Python 模块,并将路径添加到系统的PYTHONPATH. 我可以在训练 Keras 模型时使用该指标。

不幸的是,我还没有找到一种方法使自定义指标可用于 TensorFlow 模型分析 (TFMA)。

在我的交互式上下文笔记本中,我可以在创建eval_config.

import tensorflow as tf
import tensorflow_model_analysis as tfma 
from mymetric.metric import MyMetric

metrics = [MyMetric()]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='label_xf')],
        metrics_specs=metrics_specs,
        slicing_specs=[tfma.SlicingSpec()]
)
evaluator = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'], 
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config)

当我尝试执行时evaluator,该指标在指标规范中列出

metrics_specs {
  metrics {
    class_name: "MyMetric"
    config: "{\"dtype\": \"float32\", \"name\": \"my_metric\"}"
    threshold {
    }
  }
}

但执行失败并出现错误

ValueError: Unknown metric function: MyMetric

由于度量计算是通过 Apache Beam 的executor.Do函数执行的,我假设 Beam 找不到模块(即使它在 PYTHONPATH 上)。如果是这种情况,如何使模块在 PYTHONPATH 配置之外对 Apache Beam 可用?

追溯:

/usr/local/lib/python3.6/dist-packages/tensorflow_model_analysis/metrics/metric_specs.py in _deserialize_tf_metric(metric_config, custom_objects)
    741   cls_name, cfg = _tf_class_and_config(metric_config)
    742   with tf.keras.utils.custom_object_scope(custom_objects):
--> 743     return tf.keras.metrics.deserialize({'class_name': cls_name, 'config': cfg})
    744 
    745 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/metrics.py in deserialize(config, custom_objects)
   3441       module_objects=globals(),
   3442       custom_objects=custom_objects,
-> 3443       printable_module_name='metric function')
   3444 
   3445 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    345     config = identifier
    346     (cls, cls_config) = class_and_config_for_serialized_keras_object(
--> 347         config, module_objects, custom_objects, printable_module_name)
    348 
    349     if hasattr(cls, 'from_config'):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
    294   cls = get_registered_object(class_name, custom_objects, module_objects)
    295   if cls is None:
--> 296     raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
    297 
    298   cls_config = config['config']

ValueError: Unknown metric function: MyMetric
4

1 回答 1

2

您需要指定模块,以便 TFX 知道在哪里可以找到您的 MyMetric 类。一种方法是将其指定为度量规范的一部分:

from tensorflow_model_analysis import config

metric_config = [config.MetricConfig(class_name='MyMetric', module='mymodule.mymetric')]

metrics_specs = [config.MetricsSpec(metrics=metric_config)]

您还需要创建一个名为mymodule并将您的MyMetric类放入其中mymetric.py以使其工作的模块。还要确保可以从您执行代码的位置访问该模块(如果您已将其添加到您的 PYTHONPATH 应该是这种情况)。

于 2020-09-28T13:41:33.960 回答