0

这是我用来编译代码的损失函数:

def yolo_loss(y_true, yolo_outputs):    
    anchors = np.array([10.0, 14.0, 23.0, 27.0, 37.0, 58.0, 81.0, 82.0, 135.0, 169.0, 344.0, 319.0]).reshape(-1, 2)
    num_classes = 1
    num_layers = len(anchors) // 3  # default setting
    ignore_thresh = 0.0005

    anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] if num_layers == 3 else [[3, 4, 5], [1, 2, 3]]
    input_shape = tf.keras.backend.cast(tf.keras.backend.shape(yolo_outputs[0])[1:3] * 32, tf.keras.backend.dtype(y_true[0]))
    grid_shapes = [
        tf.keras.backend.cast(tf.keras.backend.shape(yolo_outputs[l])[1:3], tf.keras.backend.dtype(y_true[0])) for l in range(num_layers)
    ]
    loss = 0
    m = tf.keras.backend.shape(yolo_outputs[0])[0]  # batch size, tensor
    mf = tf.keras.backend.cast(m, tf.keras.backend.dtype(yolo_outputs[0]))

    for l in range(num_layers):
        object_mask = y_true[l][..., 4:5]
        true_class_probs = y_true[l][..., 5:]

        grid, raw_pred, pred_xy, pred_wh = yolo_head(yolo_outputs[l], anchors[anchor_mask[l]], num_classes, input_shape, calc_loss=True)
        pred_box = tf.keras.backend.concatenate([pred_xy, pred_wh])

        # Darknet raw box to calculate loss.
        raw_true_xy = y_true[l][..., :2] * grid_shapes[l][::-1] - grid
        raw_true_wh = tf.keras.backend.log(y_true[l][..., 2:4] / anchors[anchor_mask[l]] * input_shape[::-1])
        raw_true_wh = tf.keras.backend.switch(object_mask, raw_true_wh, tf.keras.backend.zeros_like(raw_true_wh))  # avoid log(0)=-inf
        box_loss_scale = 2 - y_true[l][..., 2:3] * y_true[l][..., 3:4]

        # Find ignore mask, iterate over each of batch.
        ignore_mask = tf.TensorArray(tf.keras.backend.dtype(y_true[0]), size=1, dynamic_size=True)
        object_mask_bool = tf.keras.backend.cast(object_mask, "bool")

        def loop_body(b, ignore_mask):
            true_box = tf.boolean_mask(y_true[l][b, ..., 0:4], object_mask_bool[b, ..., 0])
            iou = box_iou(pred_box[b], true_box)
            best_iou = tf.keras.backend.max(iou, axis=-1)
            ignore_mask = ignore_mask.write(b, tf.keras.backend.cast(best_iou < ignore_thresh, tf.keras.backend.dtype(true_box)))
            return b + 1, ignore_mask

        _, ignore_mask = tf.while_loop(lambda b, *args: b < m, loop_body, [0, ignore_mask])
        ignore_mask = ignore_mask.stack()
        ignore_mask = tf.expand_dims(ignore_mask, -1)

        raw_true_xy = tf.keras.backend.reshape(raw_true_xy, shape=tf.shape(raw_pred[..., 0:2]))
        object_mask = tf.keras.backend.reshape(object_mask, shape=tf.shape(raw_pred[..., 4:5]))
        true_class_probs = tf.keras.backend.reshape(true_class_probs, shape=tf.shape(raw_pred[..., 5:]))
        xy_loss = object_mask * box_loss_scale * tf.keras.backend.binary_crossentropy(raw_true_xy, raw_pred[..., 0:2], from_logits=True)
        wh_loss = object_mask * box_loss_scale * 0.5 * tf.keras.backend.square(raw_true_wh - raw_pred[..., 2:4])
        confidence_loss = (
            object_mask * tf.keras.backend.binary_crossentropy(object_mask, raw_pred[..., 4:5], from_logits=True)
            + (1 - object_mask) * tf.keras.backend.binary_crossentropy(object_mask, raw_pred[..., 4:5], from_logits=True) * ignore_mask
        )
        class_loss = object_mask * tf.keras.backend.binary_crossentropy(true_class_probs, raw_pred[..., 5:], from_logits=True)
        xy_loss = tf.keras.backend.sum(xy_loss) / mf
        wh_loss = tf.keras.backend.sum(wh_loss) / mf
        confidence_loss = tf.keras.backend.sum(confidence_loss) / mf
        class_loss = tf.keras.backend.sum(class_loss) / mf
        loss += xy_loss + wh_loss + confidence_loss + class_loss
    return loss


正在使用的模型主体是:https ://github.com/qqwweee/keras-yolo3/blob/master/yolo3/model.py#L89

调用 model.train_on_batch 时出现此错误 当我用类似的东西替换损失函数时,此错误得到缓解


def custom_loss_function(y_true, yolo_outputs):
    squared_difference = tf.square(y_true - yolo_outputs)
    return tf.reduce_mean(squared_difference, axis=-1)

因此,这让我认为损失函数会导致此错误

注:模型编译为:

model = tiny_yolo_body()
model.compile(optimizer="Adam", loss=yolo_loss)
4

0 回答 0