作为练习,我只需要使用密集层来执行文本分类。我想利用词嵌入,问题是数据集是 3D 的(样本、句子的词、嵌入维度)。我可以将 3D 数据集输入到密集层吗?
谢谢
作为练习,我只需要使用密集层来执行文本分类。我想利用词嵌入,问题是数据集是 3D 的(样本、句子的词、嵌入维度)。我可以将 3D 数据集输入到密集层吗?
谢谢
如keras 文档中所述,您可以使用 3D(或更高等级)数据作为 Dense 层的输入,但输入首先会被展平:
注意:如果层的输入具有大于 2 的秩,则它在与内核的初始点积之前被展平。
这意味着如果您的输入具有 shape (batch_size, sequence_length, dim)
,那么密集层将首先将您的数据展平为 shape (batch_size * sequence_length, dim)
,然后像往常一样应用密集层。输出将具有 shape (batch_size, sequence_length, hidden_units)
。这实际上与应用内核大小为 1 的 Conv1D 层相同,使用 Conv1D 层而不是 Dense 层可能更明确。