I am trying to use Keras to build a 1D convolutional neural network with residual connections and batch normalization, based on the paper Cardiologist-Level Arrhythmia Detection with Convolutional Neural Networks. This is the code so far:
from keras.layers import (Input, Conv1D, BatchNormalization, ReLU,
                          Dropout, MaxPooling1D, Flatten, Dense, add)
from keras.models import Model

# define model
x = Input(shape=(time_steps, n_features))
# First Conv / BN / ReLU layer
y = Conv1D(filters=n_filters, kernel_size=n_kernel, strides=n_strides, padding='same')(x)
y = BatchNormalization()(y)
y = ReLU()(y)
shortcut = MaxPooling1D(pool_size=n_pool)(y)
# First Residual block
y = Conv1D(filters=n_filters, kernel_size=n_kernel, strides=n_strides, padding='same')(y)
y = BatchNormalization()(y)
y = ReLU()(y)
y = Dropout(rate=drop_rate)(y)
y = Conv1D(filters=n_filters, kernel_size=n_kernel, strides=n_strides, padding='same')(y)
# Add Residual (shortcut)
y = add([shortcut, y])
# Repeated Residual blocks
for k in range(2, 3):  # smaller network for testing
    shortcut = MaxPooling1D(pool_size=n_pool)(y)
    y = BatchNormalization()(y)
    y = ReLU()(y)
    y = Dropout(rate=drop_rate)(y)
    y = Conv1D(filters=n_filters * k, kernel_size=n_kernel, strides=n_strides, padding='same')(y)
    y = BatchNormalization()(y)
    y = ReLU()(y)
    y = Dropout(rate=drop_rate)(y)
    y = Conv1D(filters=n_filters * k, kernel_size=n_kernel, strides=n_strides, padding='same')(y)
    y = add([shortcut, y])
z = BatchNormalization()(y)
z = ReLU()(z)
z = Flatten()(z)
z = Dense(64, activation='relu')(z)
predictions = Dense(classes, activation='softmax')(z)
model = Model(inputs=x, outputs=predictions)
# Compiling
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])
# Fitting
model.fit(train_x, train_y, epochs=n_epochs, batch_size=n_batch)
Here is a diagram of the simplified model I am trying to build.
The model described in the paper uses an increasing number of filters:
"The network consists of 16 residual blocks with 2 convolutional layers per block. The convolutional layers all have a filter length of 16 and have 64k filters, where k starts out as 1 and is incremented every 4th residual block. Every alternate residual block subsamples its inputs by a factor of 2, thus the original input is ultimately subsampled by a factor of 2^8. Whenever a residual block subsamples its input, the corresponding shortcut connection also subsamples its input using a Max Pooling operation with the same subsampling factor."
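To sanity-check my reading of that description, the shape arithmetic can be written out in plain Python. The helper names and the exact block indexing (k incremented after every 4th block, pooling on every even-numbered block) are my own interpretation of the quote, not code from the paper:

```python
def output_length(time_steps, n_blocks=16, pool_size=2):
    """Temporal length after the residual stack: every alternate
    block subsamples by pool_size, so 16 blocks give 8 poolings."""
    length = time_steps
    for block in range(1, n_blocks + 1):
        if block % 2 == 0:  # every alternate block subsamples
            length //= pool_size
    return length

def filters_at(block, base_filters=64):
    """64k filters, with k starting at 1 and incremented
    after every 4th residual block."""
    k = 1 + (block - 1) // 4
    return base_filters * k

print(output_length(16384))  # 16384 / 2**8 = 64
print([filters_at(b) for b in range(1, 17)])
```

Under this reading the filter counts per block come out as 64, 64, 64, 64, 128, ..., 256, and the input length shrinks by 2^8 overall, which matches the quoted text.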
However, I can only get it to work when every Conv1D layer uses the same number of filters (k=1), with strides=1 and padding='same', and without applying any MaxPooling1D. Any change to these parameters causes a tensor size mismatch, and the model fails to compile with the following error:
ValueError: Operands could not be broadcast together with shapes (70, 64) (70, 128)
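As far as I can tell, the error comes from the elementwise add: once the filter count doubles, the shortcut still carries 64 channels while the main path carries 128, and those trailing dimensions cannot be broadcast together. The same shape failure can be reproduced with plain numpy, using the (70, 64) and (70, 128) shapes from the error above:

```python
import numpy as np

# Shapes taken from the error message: the shortcut kept 64
# channels while the main path now has 128.
shortcut = np.zeros((70, 64))
y = np.zeros((70, 128))

try:
    shortcut + y  # elementwise add, like Keras add([shortcut, y])
except ValueError as e:
    print("broadcast fails:", e)

# If the shortcut had been brought up to 128 channels first,
# the elementwise add would be valid.
projected = np.zeros((70, 128))  # stand-in for a reshaped shortcut
print((projected + y).shape)     # (70, 128)
```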
Does anyone know how to resolve this size mismatch and make it work?
Moreover, the mismatch gets worse when the input has more than one channel (or feature)! Is there a way to handle multiple channels?
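For reference, here is a minimal sketch of one residual block in which the shortcut is subsampled with the same factor as the main path and then projected to the new filter count with a 1x1 Conv1D. The 1x1 projection is a common trick for matching channel counts, not something the paper specifies, and the hyperparameter values are placeholders:

```python
from tensorflow.keras.layers import (Input, Conv1D, BatchNormalization,
                                     ReLU, Dropout, MaxPooling1D, add)
from tensorflow.keras.models import Model

def residual_block(x, filters, kernel_size=16, pool_size=2, drop_rate=0.5):
    # Shortcut: subsample with the same factor as the main path,
    # then project to the new channel count so add() sees matching shapes.
    shortcut = MaxPooling1D(pool_size=pool_size)(x)
    shortcut = Conv1D(filters, 1, padding='same')(shortcut)

    y = BatchNormalization()(x)
    y = ReLU()(y)
    y = Dropout(drop_rate)(y)
    # Subsample the main path via strides in the first convolution.
    y = Conv1D(filters, kernel_size, strides=pool_size, padding='same')(y)
    y = BatchNormalization()(y)
    y = ReLU()(y)
    y = Dropout(drop_rate)(y)
    y = Conv1D(filters, kernel_size, padding='same')(y)
    return add([shortcut, y])

inp = Input(shape=(256, 1))          # any channel count works here
out = residual_block(inp, filters=128)
print(Model(inp, out).output_shape)  # (None, 128, 128)
```

Because the projection rewrites the channel dimension, the block builds regardless of how many input channels (features) the signal has, but I am not sure this is what the authors intended.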