9

在 keras 中,我想自定义我的损失函数,它不仅需要 (y_true, y_pred) 作为输入,还需要使用网络内部层的输出作为输出层的标签。这张图显示了网络布局

这里,内部输出是 xn,它是一个 1D 特征向量。在右上角,输出为xn',即xn的预测。换句话说,xn 是 xn' 的标签。

而 [Ax, Ay] 传统上称为 y_true,而 [Ax',Ay'] 是 y_pred。

我想将这两个损失组件合二为一,共同训练网络。

任何想法或想法都非常感谢!

4

3 回答 3

12

我已经找到了一条出路,以防万一有人在寻找相同的方法,我在此处发布(基于本文中给出的网络):

这个想法是定义自定义的损失函数并将其用作网络的输出。(符号:A是变量的真实标签,是变量AA'预测值A

def customized_loss(args):
    #A is from the training data
    #S is the internal state
    A, A', S, S' = args 
    #customize your own loss components
    loss1 = K.mean(K.square(A - A'), axis=-1)
    loss2 = K.mean(K.square(S - S'), axis=-1)
    #adjust the weight between loss components
    return 0.5 * loss1 + 0.5 * loss2

 def model():
     #define other inputs
     A = Input(...) # define input A
     #construct your model 
     cnn_model = Sequential()
     ...
     # get true internal state
     S = cnn_model(prev_layer_output0)
     # get predicted internal state output
     S' = Dense(...)(prev_layer_output1)
     # get predicted A output
     A' = Dense(...)(prev_layer_output2)
     # customized loss function
     loss_out = Lambda(customized_loss, output_shape=(1,), name='joint_loss')([A, A', S, S'])
     model = Model(input=[...], output=[loss_out])
     return model

  def train():
      m = model()
      opt = 'adam'
      model.compile(loss={'joint_loss': lambda y_true, y_pred:y_pred}, optimizer = opt)
      # train the model 
      ....
于 2017-01-23T02:37:57.010 回答
0

我对此实施持保留意见。在合并层计算的损失被传播回两个合并的分支。通常,您希望仅通过一层传播它。

于 2019-07-06T10:08:54.760 回答
0

首先,您应该使用Functional API。然后你应该将网络输出定义为输出加上内部层的结果,将它们合并为一个输出(通过连接),然后制作一个自定义损失函数,然后将合并的输出分成两部分并进行损失计算在其自己的。

就像是:

def customLoss(y_true, y_pred):
    #loss here
    internalLayer = Convolution2D()(inputs) #or other layers
    internalModel = Model(input=inputs, output=internalLayer)
    tmpOut = Dense(...)(internalModel)
    mergedOut = merge([tmpOut, mergedOut], mode = "concat", axis = -1)
    fullModel = Model(input=inputs, output=mergedOut)

    fullModel.compile(loss = customLoss, optimizer = "whatever")
于 2017-01-16T08:33:10.650 回答