2

我正在尝试在 TensorFlow (tensorflow==1.13.1) 中训练一个三元组损失模型,但在第一次迭代后损失会转到 NAN。我已经从这个论坛阅读了很多提示,但到目前为止没有任何效果。较小的学习率(从 0.01 更改为 0.0001)。还尝试了更大的批次大小,因为有人建议如果批次太小,批次可能不包含完整的三元组。但是,我使用以下代码自己构建了批次:'''

def get_triplets(self):
    a = random.choice(self.anchors)
    p = random.choice(self.positives[a])
    n = random.choice(self.negatives[a])

    return self.images[a], self.images[p], self.images[n]

def get_triplets_batch(self, n):
    idxs_a, idxs_p, idxs_n = [], [], []
    for _ in range(n):
        a, p, n = self.get_triplets()
        idxs_a.append(a)
        idxs_p.append(p)
        idxs_n.append(n)
    return idxs_a, idxs_p, idxs_n

我还看到人们在预测中添加了一个小数字来避免这个 NAN。我不知道我是否必须为三元组丢失或在哪里这样做。训练代码是这样的

'''

_,l,summary_str = sess.run([train_step,loss,merged],feed_dict={anchor_input:batch_anchor,positive_input:batch_positive,negative_input:batch_negative})

我是使用 StackOverflow 的新手,所以我想包含我在 Github 中使用的文件,以防我提供的信息太少。 https://github.com/AndreasVer/Triplet-loss-tensorflow.git

4

1 回答 1

0

我已经发布了一个用于生成三元组的包,以避免 NaN 丢失。这通常发生在您的批次没有任何阳性对时。

https://github.com/ma7555/kerasgen(免责声明:我是所有者)

于 2022-02-24T01:18:28.137 回答