您可以使用 tf.slice() 只提取模型输出中每个样本（即最后一个维度）的前三个元素。
import numpy as np
import tensorflow as tf

# Run eagerly so the demo can print concrete tensor values. TF 2.x already
# runs eagerly by default, so there the call is effectively a no-op; it only
# matters under TF 1.x.
tf.compat.v1.enable_eager_execution()

# Stand-in for a model output: 2 rows (an arbitrary batch size) of 7 values.
out = np.arange(0, 14).reshape(2, 7)
print(out)
# [[ 0  1  2  3  4  5  6]
#  [ 7  8  9 10 11 12 13]]

# Wrap the array in a TF variable so tf.slice operates on a TF object.
out_tf = tf.Variable(out)

# begin=[0, 0] starts at the first row and first column; size=[-1, 3] takes
# every row (-1 means "all remaining elements in this dimension") and the
# first 3 columns.
xyz = tf.slice(out_tf, [0, 0], [-1, 3])
print(xyz)
# Expected values, shape (2, 3):
# [[0 1 2]
#  [7 8 9]]
可以把类似的切片逻辑封装进自定义指标（metric）函数中，从而只对前三个元素（X、Y、Z 坐标）计算所需的统计量，例如中位数：
def xyz_median(y_true, y_pred):
    """Return the median of the X, Y, Z values in the model prediction.

    Written as a custom metric with the ``(y_true, y_pred)`` signature that
    Keras passes to metric functions.

    Args:
        y_true: Ground-truth tensor. Unused; accepted only to satisfy the
            metric signature.
        y_pred: Model output. X, Y, Z are assumed to be the first three
            entries of the last axis -- TODO confirm against the model's
            output layout.

    Returns:
        Scalar tensor holding the median of every X, Y, Z value in the batch.
        When the count is even this is the mean of the two middle values,
        i.e. the 50th percentile with linear interpolation.
    """
    # Take X, Y, Z from the prediction passed in, not from any global tensor.
    # ``[..., :3]`` keeps the first three entries of the last axis for input
    # of any rank; for 2-D input it equals tf.slice(y_pred, [0, 0], [-1, 3]).
    xyz = y_pred[..., :3]

    # Median over all values, using core TF ops only (no extra dependency):
    # sort the flattened values, then average the middle one or two.
    values = tf.sort(tf.reshape(xyz, [-1]))
    count = tf.size(values)
    lower = values[(count - 1) // 2]
    upper = values[count // 2]
    return (lower + upper) / 2