我正在尝试将使用 tensorflow 层编写的 CNN 转换为在 tensorflow 中使用 keras api(我正在使用 TF 1.x 提供的 keras api),并且在编写自定义损失函数来训练模型时遇到问题。
根据本指南,在定义损失函数时,它需要参数(y_true, y_pred)
https://www.tensorflow.org/guide/keras/train_and_evaluate#custom_losses
def basic_loss_function(y_true, y_pred):
return ...
但是,在我看到的每个示例中,y_true都与模型直接相关(在简单的情况下,它是网络的输出)。在我的问题中,情况并非如此。如果我的损失函数依赖于一些与模型张量无关的训练数据,如何实现这一点?
具体来说,这是我的问题:
我正在尝试学习在成对图像上训练的图像嵌入。我的训练数据包括图像对和图像对之间匹配点的注释(图像坐标)。输入特征只是图像对,网络是在孪生配置中训练的。
我能够使用 tensorflow 层成功地实现这一点,并使用 tensorflow 估计器成功地训练它。我当前的实现从一个大型的 tf 记录数据库构建一个 tf 数据集,其中的特征是一个包含图像和匹配点数组的字典。在我可以轻松地将这些图像坐标数组输入到损失函数之前,但目前尚不清楚如何执行此操作。