
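I wrote a custom model in TF 2.0 for style transfer, based on this paper. In short, the loss function of the proposed algorithm requires evaluating 3 loss components. The model takes 2 input images, say Ic and Is (c stands for content, s for style), and produces a pastiche image O.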

In a single training step, the network receives each of the following pairs as input and produces the corresponding image:

  • Ic, Is -> O (the desired output)
  • Ic, Ic -> O (identity loss 1)
  • Is, Is -> O (identity loss 2)
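The three forward passes listed above can be sketched in plain Python. This is a toy illustration, not the code from the paper or from this post: `toy_model` is a hypothetical stand-in for the autoencoder, images are flat lists of floats, and each identity loss is assumed to be the mean squared error between a reconstruction and its input (the paper's exact definition may differ).

```python
def toy_model(content, style, alpha=0.5):
    """Hypothetical stand-in for the autoencoder: a pixel-wise blend of both inputs."""
    return [(1 - alpha) * c + alpha * s for c, s in zip(content, style)]


def mse(a, b):
    """Mean squared error between two equal-length lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def training_step_outputs(model, ic, is_):
    """Run the three forward passes made in one training step."""
    o = model(ic, is_)      # Ic, Is -> O (the stylized output)
    o_cc = model(ic, ic)    # Ic, Ic -> O (input to identity loss 1)
    o_ss = model(is_, is_)  # Is, Is -> O (input to identity loss 2)
    return o, o_cc, o_ss


def identity_losses(model, ic, is_):
    """Assumed definition: each identity loss is the MSE between a reconstruction and its input."""
    _, o_cc, o_ss = training_step_outputs(model, ic, is_)
    return mse(o_cc, ic), mse(o_ss, is_)


ic = [0.1, 0.5, 0.9]
is_ = [0.8, 0.2, 0.4]
print(identity_losses(toy_model, ic, is_))
```

A model that reproduces its input when content and style coincide (like the blend above) gets zero identity losses; any model that distorts the image in that case is penalized.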

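A feature network then evaluates the different loss components. The main network therefore needs 3 forward passes per step, while the feature network is not part of the trainable network, so its weights are non-trainable.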
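A frozen feature network of this kind might be built as below. This is a sketch under assumptions: VGG19 is assumed only because the code calls `tf.keras.applications.vgg19.preprocess_input`, the chosen layers are illustrative, and `weights=None` is used solely so the example runs without downloading ImageNet weights.

```python
import tensorflow as tf

# Assumed feature extractor: VGG19 without its classification head.
# weights=None skips the ImageNet download; a real setup would use weights="imagenet".
vgg = tf.keras.applications.VGG19(include_top=False, weights=None)

# Illustrative choice of layers whose activations feed the content/style losses.
layer_names = ["block1_conv1", "block2_conv1", "block3_conv1", "block4_conv1"]
feature_net = tf.keras.Model(
    inputs=vgg.input,
    outputs=[vgg.get_layer(name).output for name in layer_names],
)

# Freeze every variable so the optimizer never updates the feature network.
feature_net.trainable = False

# One forward pass on a dummy batch returns one activation map per chosen layer.
features = feature_net(tf.zeros([1, 64, 64, 3]))
```

Freezing only moves the variables out of `trainable_weights`; when such a network is assigned as an attribute of a `keras.Model`, it is still tracked as a sublayer and its weights are still written by `save_weights`.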

The code is as follows:

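import tensorflow as tf
from tensorflow import keras

class StyleTranasfer(keras.Model):
  def __init__(self, autoencoder, net):
    super(StyleTranasfer, self).__init__()
    self.autoencoder = autoencoder  # trainable network that produces O
    self.net = net                  # frozen feature network used inside the loss

  def compile(self, optimizer, loss):
    super(StyleTranasfer, self).compile()
    self.optimizer = optimizer
    self.loss_fn = loss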

  @tf.function
  def call(self, inputs, training=False):
    content_images, style_images = inputs
    # Both inputs are resized to 224x224 before entering the autoencoder
    return self.autoencoder((tf.image.resize(content_images, [224, 224]),
                             tf.image.resize(style_images, [224, 224])))

  @tf.function
  def test_step(self, data):
    # Unpack the (content, style) batch
    content_images, style_images = data

    stylyzed_output = self.call(data)
    stylyzed_content_output = self.call((content_images, content_images))
    stylyzed_style_output = self.call((style_images, style_images))

    loss, style_loss, content_loss, identity_loss_1, identity_loss_2 = self.loss_fn(
        self.prepro(stylyzed_output),
        self.prepro(content_images),
        self.prepro(style_images),
        self.prepro(stylyzed_style_output),
        self.prepro(stylyzed_content_output),
        self.net)
    return {"loss": loss,
            "style_loss": style_loss,
            "content_loss": content_loss,
            "identity_loss_1": identity_loss_1,
            "identity_loss_2": identity_loss_2}
  
  @tf.function
  def train_step(self, data):
    content_images, style_images = data

    with tf.GradientTape() as tape:
      # Three forward passes: (Ic, Is), (Ic, Ic) and (Is, Is)
      stylyzed_output = self.call(data)
      stylyzed_content_output = self.call((content_images, content_images))
      stylyzed_style_output = self.call((style_images, style_images))

      loss, style_loss, content_loss, identity_loss_1, identity_loss_2 = self.loss_fn(
          self.prepro(stylyzed_output),
          self.prepro(content_images),
          self.prepro(style_images),
          self.prepro(stylyzed_style_output),
          self.prepro(stylyzed_content_output),
          self.net)
    # Only the autoencoder is updated; the feature network self.net stays frozen
    grads = tape.gradient(loss, self.autoencoder.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.autoencoder.trainable_weights))
    return {"loss": loss,
            "style_loss": style_loss,
            "content_loss": content_loss,
            "identity_loss_1": identity_loss_1,
            "identity_loss_2": identity_loss_2}


  def prepro(self,img):
    return  tf.keras.applications.vgg19.preprocess_input(128.0*img)

I can train the model without any problem, but when I try to call save_weights I get:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-20-785ae311d57f>", line 4, in <module>
    model.save_weights('prova.h5')
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py", line 2108, in save_weights
    hdf5_format.save_weights_to_hdf5_group(f, self.layers)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/hdf5_format.py", line 642, in save_weights_to_hdf5_group
    param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
  File "/usr/local/lib/python3.7/dist-packages/h5py/_hl/group.py", line 139, in create_dataset
    self[name] = dset
  File "/usr/local/lib/python3.7/dist-packages/h5py/_hl/group.py", line 373, in __setitem__
    h5o.link(obj.id, self.id, name, lcpl=lcpl, lapl=self._lapl)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5o.pyx", line 202, in h5py.h5o.link
RuntimeError: Unable to create link (name already exists)

It seems that some weights are duplicated (including their names), which raises this error when they are saved.
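One way to check this hypothesis is to look for repeated weight names inside each top-level layer. The traceback suggests that `save_weights_to_hdf5_group` receives `self.layers` and calls `create_dataset(name, ...)` inside a per-layer group, and HDF5 refuses a second dataset with the same name in one group. The helper below is a hypothetical, framework-free sketch that takes a mapping from layer name to weight names; with Keras it could presumably be fed something like `{l.name: [w.name for w in l.weights] for l in model.layers}`. The example data is invented for illustration.

```python
from collections import Counter


def find_duplicate_weight_names(layer_weights):
    """Return {layer_name: [repeated weight names]} for layers whose weights clash.

    layer_weights maps each top-level layer name to the list of its weight names,
    mirroring an HDF5 weights file with one group per layer and one dataset per
    weight name inside that group.
    """
    duplicates = {}
    for layer_name, names in layer_weights.items():
        counts = Counter(names)
        repeated = sorted(name for name, n in counts.items() if n > 1)
        if repeated:
            duplicates[layer_name] = repeated
    return duplicates


# Hypothetical example: two sub-blocks built with the same layer names inside one
# autoencoder end up with identical variable names.
example = {
    "autoencoder": ["block1_conv1/kernel:0", "block1_conv1/bias:0",
                    "block1_conv1/kernel:0", "block1_conv1/bias:0",
                    "decoder_conv/kernel:0"],
    "net": ["block1_conv1/kernel:0", "block1_conv1/bias:0"],
}
print(find_duplicate_weight_names(example))
# {'autoencoder': ['block1_conv1/bias:0', 'block1_conv1/kernel:0']}
```

Names shared between different top-level layers (here `autoencoder` and `net`) are not reported, because each layer gets its own group and only a clash within one group triggers "name already exists".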

...does anyone have an idea?
