0

我正在构建一个网络,它将字符串拆分为单词,将单词拆分为字符,嵌入每个字符,然后通过将字符聚合为单词并将单词聚合为字符串来计算该字符串的向量表示。使用双向 gru 层执行聚合并注意。
为了测试这个东西,假设我对这个字符串中的 5 个单词和 5 个字符感兴趣。在这种情况下,我的转变是:

["Some string"] -> ["Some","strin","","",""] -> 
["Some_","string","_____","_____","_____"] where _ is the padding symbol ) -> 
[[1,2,3,4,0],[1,5,6,7,8],[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0]] (shape 5x5)

接下来我有一个嵌入层,它将每个字符变成一个长度为 6 的嵌入向量。所以我的特征变成了一个 5x5x6 矩阵。然后我将此输出传递给双向 gru 层并执行一些其他操作,这些操作在这种情况下并不重要,我相信。

问题是当我用迭代器运行它时,比如

for string in strings:
    output = model(string)

它似乎工作得很好(字符串是从 5x5 的切片创建的 tf 数据集),所以它是一堆 5 x 5 矩阵。

但是,当我转到训练或使用预测等功能在数据集级别工作时,模型会失败:

model.predict(strings.batch(1))
ValueError: Input 0 of layer bidirectional is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: (None, 5, 5, 6)

据我从文档中了解到,双向层将 3d 张量作为输入:[batch, timesteps, feature],因此在这种情况下,我的输入形状应如下所示:[batch_size,timesteps,(5,5,6)]

所以问题是我应该对输入数据应用哪种转换来获得这种形状?

4

1 回答 1

1

对于双向输入层,如果您使用 GRU,请使用return_sequences=True, 来获得 3 维输出。由于 GRU 输出是 2D,return_sequences 将为您提供 3D 输出。对于堆叠的双向层输入应该是 3D 形状。

示例代码

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
model = keras.Sequential()

model.add(
    layers.Bidirectional(layers.GRU(64, return_sequences=True), input_shape=(5, 10))
)
model.add(layers.Bidirectional(layers.GRU(32)))
model.add(layers.Dense(10))

model.summary()

输出

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
bidirectional_3 (Bidirection (None, 5, 128)            38400     
_________________________________________________________________
bidirectional_4 (Bidirection (None, 64)                41216     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650       
=================================================================
Total params: 80,266
Trainable params: 80,266
Non-trainable params: 0
___________________________
于 2021-10-20T02:18:39.007 回答