0

我正在尝试建立一个具有相当复杂的基础网络的连体模型。构建基础网络后,我使用以下代码构建我的连体网络:

base_network=create_base_model(0.2)
img1=Input(shape=(256,256,3))
img2=Input(shape=(256,256,3))
text_input1 = Input(shape=(), dtype=tf.string, name='text_1')
text_input2 = Input(shape=(), dtype=tf.string, name='text_2')
output1= base_network([img1, text_input1])
output2= base_network([img2, text_input2])
distance = Lambda(euclidean_distance)([output1, output2])
siamese_model = Model([[img1,text_input1], [img2, text_input2]], distance)

基础网络的输出形式model

model=Model(inputs=[input1,input2], outputs=[z])

问题是在训练连体网络之后,我想使用基础网络的输出作为嵌入,这样我就可以运行无监​​督学习算法。但是,在训练 siamese 网络时,我想一次训练 10 个 epoch,然后保存它并在需要时继续训练。在这种情况下,当我保存并重新加载连体模型时,我不确定如何保存/访问基础网络。例如,我得到了需要 2 个输入的连体模型的下图(我的基本模型使用 2 个输入,所以技术上我有 4 个输入,如图所示),但我想使用只需要 1 个输入的基本模型训练(技术上是 2,因为我的基本模型使用 2)。

谁能给我关于如何使用保存的连体模型加载更新的基本模型的建议,或者是否有更好的方法首先保存它?

非常感谢。

在此处输入图像描述

4

1 回答 1

0
if epoch %5 == 0 
   path = f'/tmp/model{epoch}.h5'
   base_network.save(path)

base_network = tf.keras.models.load_model(path)

这不好吗?

于 2021-04-29T03:17:55.007 回答