我一直在寻找如何在 Keras 的二维阵列上实现地球移动器损失。我的输入是 2d(不是图像)。我已经从https://github.com/master/nima/blob/4a685993d0e5942cf70de54f1c31a218827ccba3/nima.py#L31尝试了以下代码, 但它似乎不起作用。任何人都可以帮忙吗?
def ecdf(p):
n = p.get_shape().as_list()[1]
indices = tril_indices(n)
indices = tf.transpose(tf.stack([indices[1], indices[0]]))
ones = tf.ones([n * (n + 1) / 2])
triang = tf.scatter_nd(indices, ones, [n, n])
return tf.matmul(p, triang)
def emd_loss(p, p_hat, r=2, scope=None):
with tf.name_scope(scope, 'EmdLoss', [p, p_hat]):
ecdf_p = ecdf(p)
ecdf_p_hat = ecdf(p_hat)
emd = tf.reduce_mean(tf.pow(tf.abs(ecdf_p - ecdf_p_hat), r),
axis=-1)
emd = tf.pow(emd, 1 / r)
return tf.reduce_mean(emd)