我正在编写一个自定义损失函数,它在忽略 nans 的同时计算均方误差。问题是我的数据是偶尔有 NaN 像素的图像。我只是想忽略这些 nan 像素并计算预测和数据之间的平方和误差,然后计算示例的平均值。如果我要在 Tensorflow 中为此编写一个函数,我会写:
def nanmean_squared_error(y_true, y_pred):
residuals = y_true - y_pred
residuals_no_nan = tf.where(tf.is_nan(residuals), tf.zeros_like(residuals), residuals)
sum_residuals = tf.reduce_sum(residuals_no_nan, [1, 2])
return sum_residuals
但是此代码不能用作自定义 Keras 损失函数。
我相信我可以使用 keras.backend.switch/zeros_like/sum 而不是 tensorflow 版本。但我找不到 tf.is_nan 的任何替代品。有没有人对如何实施这个有建议?