I created a custom Keras metric, similar to the demo implementation below:
import tensorflow as tf

class MyMetric(tf.keras.metrics.Mean):

    def __init__(self, name='my_metric', dtype=None):
        super(MyMetric, self).__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super(MyMetric, self).update_state(
            y_pred, sample_weight=sample_weight)
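I converted the implementation into a Python module with init/main files and added its path to the system's PYTHONPATH. I can use the metric when training Keras models.
Unfortunately, I haven't found a way to make the custom metric available to TensorFlow Model Analysis (TFMA).
In my interactive context notebook, I can import the metric and use it when creating the eval_config: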
import tensorflow as tf
import tensorflow_model_analysis as tfma
from tfx.components import Evaluator
from mymetric.metric import MyMetric

metrics = [MyMetric()]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='label_xf')],
    metrics_specs=metrics_specs,
    slicing_specs=[tfma.SlicingSpec()]
)

# example_gen, trainer and model_resolver are defined earlier in the notebook.
evaluator = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config)
When I try to run the evaluator, the metric is listed in the metrics specs:
metrics_specs {
  metrics {
    class_name: "MyMetric"
    config: "{\"dtype\": \"float32\", \"name\": \"my_metric\"}"
    threshold {
    }
  }
}
but the execution fails with the error
ValueError: Unknown metric function: MyMetric
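Since the metric evaluation is executed through Apache Beam's executor.Do function, I assume that Beam can't find the module (even though it is on the PYTHONPATH). If that is the case, how can I make the module available to Apache Beam beyond the PYTHONPATH configuration?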
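One direction that may be worth checking, offered as an assumption rather than a confirmed fix: the metrics spec dump above carries only a class_name and a config, with no hint of where the class lives. TFMA's MetricConfig message also has a module field, so a spec that names the module explicitly might look like the fragment below. Whether the installed TFMA version honours this field during deserialization is not verified here, and the module still has to be importable by the Beam workers.

```
metrics_specs {
  metrics {
    class_name: "MyMetric"
    # Assumption: the module path matches the import used in the notebook.
    module: "mymetric.metric"
    config: "{\"dtype\": \"float32\", \"name\": \"my_metric\"}"
  }
}
```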
Traceback:
/usr/local/lib/python3.6/dist-packages/tensorflow_model_analysis/metrics/metric_specs.py in _deserialize_tf_metric(metric_config, custom_objects)
741 cls_name, cfg = _tf_class_and_config(metric_config)
742 with tf.keras.utils.custom_object_scope(custom_objects):
--> 743 return tf.keras.metrics.deserialize({'class_name': cls_name, 'config': cfg})
744
745
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/metrics.py in deserialize(config, custom_objects)
3441 module_objects=globals(),
3442 custom_objects=custom_objects,
-> 3443 printable_module_name='metric function')
3444
3445
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
345 config = identifier
346 (cls, cls_config) = class_and_config_for_serialized_keras_object(
--> 347 config, module_objects, custom_objects, printable_module_name)
348
349 if hasattr(cls, 'from_config'):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
294 cls = get_registered_object(class_name, custom_objects, module_objects)
295 if cls is None:
--> 296 raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
297
298 cls_config = config['config']
ValueError: Unknown metric function: MyMetric
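The last two frames suggest that deserialization resolves the metric purely by class name, looking in the supplied custom objects and in the built-in Keras metrics, and raises when neither contains it. The sketch below is a simplified stand-in for that lookup, not the Keras implementation; `lookup_class`, `deserialize` and the placeholder `MyMetric` are hypothetical names used only for illustration.

```python
# Simplified model of a name-based lookup like the one in the traceback.
# NOT the Keras code; every name here is illustrative.


class MyMetric:
    """Placeholder standing in for the custom metric class."""


def lookup_class(class_name, custom_objects=None, module_objects=None):
    """Resolve a class by name, checking custom objects before built-ins."""
    for registry in (custom_objects or {}, module_objects or {}):
        if class_name in registry:
            return registry[class_name]
    return None


def deserialize(config, custom_objects=None, module_objects=None):
    """Rebuild an object from a {'class_name': ..., 'config': ...} dict."""
    cls = lookup_class(config["class_name"], custom_objects, module_objects)
    if cls is None:
        raise ValueError("Unknown metric function: " + config["class_name"])
    return cls()


serialized = {"class_name": "MyMetric", "config": {}}

# With no registry entry for the name, the lookup fails as in the question.
try:
    deserialize(serialized)
    lookup_failed = False
except ValueError as err:
    lookup_failed = True
    error_message = str(err)

# Once the class is supplied under its name, the same config resolves.
restored = deserialize(serialized, custom_objects={"MyMetric": MyMetric})
```

If this model is roughly right, the failure would come from the class name not being registered at deserialization time, rather than from the module being absent from the PYTHONPATH.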