在阅读了Bahdanau 论文并将其翻译成当前的 tf.contrib.seq2seq API 之后,我对应该输入解码器的内容感到困惑。特别是,TrainingHelper 看起来应该收到一个时移的标签列表。
下面是我的工作示例,但我不确定它是否正确。
# Given:
# annotations: encoder outputs, reshaped to
# (batch_size, time, encoder_size)
# labels: ground truth, shaped (batch_size, FORECAST_HORIZON)
if params.get('ATTENTION') == 'Bahdanau':
bahdanau = tf.contrib.seq2seq.BahdanauAttention(
num_units=ATTENTION_SIZE,
memory=annotations,
normalize=False,
name='BahdanauAttention')
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
cell=tf.nn.rnn_cell.BasicLSTMCell(DECODER_SIZE, forget_bias=1.0),
attention_mechanism=bahdanau,
output_attention=False,
name="attention_wrapper")
helper = tf.contrib.seq2seq.TrainingHelper(
inputs=annotations, # ??????
sequence_length=[WINDOW_LENGTH]*BATCH_SIZE,
name="TrainingDecoderHelper")
注意倒数第三行。
TrainingHelper 是否应该将编码器注释输入注意力包装的解码器系统?
- 亲:如果
inputs
不是 shape likeannotations
,则 AttentionWrapper 最终会抱怨形状 - 这种形状在系统中出现的唯一位置是在编码器中。 - 缺点:如果这是正确的,解码器从哪里获得基本事实?
- 缺点:注意力包装的解码器(
attn_cell
)已经知道从哪里获取注释(这不是注意力机制的重点吗?)
无论如何,实际上,我得到了一个可训练的系统,但在我看来有些可疑(包括它相对于简单的 LSTM 表现不佳,但目前这绝对是切线)。