-1

我看到了一个包含以下功能的人脸检测模型。但我不明白 expand_dims 函数的用途。谁能解释一下它是什么以及我们为什么使用它?

def get_embedding(model,face_pixels):
    face_pixels=face_pixels.astype('float32')
    mean, std=face_pixels.mean(),face_pixels.std()
    face_pixels=(face_pixels-mean)/std
    samples=expand_dims(face_pixels,axis=0)
    yhat=model.predict(samples)
    return yhat[0]
4

1 回答 1

0

tf.keras.Conv2D层期望输入具有 4D 形状:

(n_samples, height, width, channels)

大多数加载图像的库将以 3D 形式加载,如下所示:

(height, width, channels)

通过使用np.expand_dims(image, axis=0)or tf.expand_dims(image, axis=0),您可以在开始时添加批处理维度,从而有效地将数据转换为 KerasConv2D图层所需的 4D 格式。例如:

(224, 224, 3)

至:

(1, 224, 224, 3)

如果您提供Conv2D3D 数据,它将提供如下内容:

ValueError:检查输入时出错:预期 conv2d_19_input 有 4 个维度,但得到了形状为 (60000, 28, 28) 的数组

于 2021-03-01T17:11:37.493 回答