1

Keras我用aMLP和 a训练了两个神经网络Bidirectional LSTM

我的任务是预测句子中的单词顺序,因此对于每个单词,神经网络必须输出一个实数。当处理一个包含 N 个单词的句子时,将输出中的 N 个实数排序,以获得表示单词位置的整数。

我在数据集上使用相同的数据集和相同的预处理。唯一不同的是,在LSTM数据集中我添加了填充以获得相同长度的序列。

在预测阶段,LSTM我排除了从填充向量创建的预测,因为我在训练阶段屏蔽了它们。

MLP架构:

mlp = keras.models.Sequential()

# add input layer
mlp.add(
    keras.layers.Dense(
        units=training_dataset.shape[1],
        input_shape = (training_dataset.shape[1],),
        kernel_initializer=keras.initializers.RandomUniform(minval=-0.05, maxval=0.05, seed=None),
        activation='relu')
    )

# add hidden layer
mlp.add(
    keras.layers.Dense(
        units=training_dataset.shape[1] + 10,
        input_shape = (training_dataset.shape[1] + 10,),
        kernel_initializer=keras.initializers.RandomUniform(minval=-0.05, maxval=0.05, seed=None),
        bias_initializer='zeros',
        activation='relu')
    )

# add output layer
mlp.add(
    keras.layers.Dense(
        units=1,
        input_shape = (1, ),
        kernel_initializer=keras.initializers.RandomUniform(minval=-0.05, maxval=0.05, seed=None),
        bias_initializer='zeros',
        activation='linear')
    )

双向 LSTM 架构:

model = tf.keras.Sequential()
model.add(Masking(mask_value=0., input_shape=(timesteps, features)))
model.add(Bidirectional(LSTM(units=20, return_sequences=True), input_shape=(timesteps, features)))
model.add(Dropout(0.2))
model.add(Dense(1, activation='linear'))

使用 LSTM 可以更好地解决该任务,它应该可以很好地捕获单词之间的依赖关系。

但是,用MLP我取得了很好的成绩,但用LSTM的结果却很糟糕。

由于我是初学者,有人能理解我的LSTM架构有什么问题吗?我快疯了。

提前致谢。

4

1 回答 1

1

对于这个问题,我对 MLP 表现更好其实并不感到惊讶。

LSTM 的体系结构,无论是否是双向的,都假设位置对结构非常重要。彼此相邻的词比更远的词更有可能相关。

但是对于您的问题,您已经删除了该位置并正在尝试恢复它。对于这个问题,具有全局信息的 MLP 可以在排序方面做得更好。

也就是说,我认为在改进 LSTM 模型方面仍有一些工作要做。


您可以做的一件事是确保每个模型的复杂性相似。您可以使用count_params轻松完成此操作。

mlp.count_params()
model.count_params()

如果我不得不猜测,你的 LSTM 要小得多。只有 20 个units,对于 NLP 问题来说似乎很小。我用于512产品分类问题来处理字符级信息(大小为 128 的词汇,大小为 50 的嵌入)。在更大数据集(如AWD-LSTM )上训练的词级模型可以进入数千个units.

所以你可能想增加这个数字。您可以通过增加unitsLSTM 中的数量直到参数计数相似来获得两个模型之间的苹果对苹果的比较。但是你不必止步于此,你可以继续增加尺寸,直到你开始过度拟合或者你的训练开始花费太长时间。

于 2020-04-18T21:11:47.490 回答