gather_tree
in究竟是如何contrib.seq2seq
工作的?我可以看到它采用了预测的 ID 和梁父 ID,并以某种方式返回了最终的梁,但引擎盖下到底发生了什么?似乎没有任何 Python 代码库可供我检查以找出答案。API不是很解释;
有代码源tf.contrib.seq2seq.gather_tree
吗?我正在使用 TensorFlow 1.3,往里gen_beam_search_ops.py
看似乎没有帮助。
gather_tree
in究竟是如何contrib.seq2seq
工作的?我可以看到它采用了预测的 ID 和梁父 ID,并以某种方式返回了最终的梁,但引擎盖下到底发生了什么?似乎没有任何 Python 代码库可供我检查以找出答案。API不是很解释;
有代码源tf.contrib.seq2seq.gather_tree
吗?我正在使用 TensorFlow 1.3,往里gen_beam_search_ops.py
看似乎没有帮助。
代码详细如下:
def gather_tree_py(values, parents):
"""Gathers path through a tree backwards from the leave nodes. Used
to reconstruct beams given their parents."""
beam_length = values.shape[0]
num_beams = values.shape[1]
res = np.zeros_like(values)
res[-1, :] = values[-1, :]
for beam_id in range(num_beams):
parent = parents[-1][beam_id]
for level in reversed(range(beam_length - 1)):
res[level, beam_id] = values[level][parent]
parent = parents[level][parent]
return np.array(res).astype(values.dtype)
def gather_tree(values, parents):
"""Tensor version of gather_tree_py"""
res = tf.py_func(
func=gather_tree_py, inp=[values, parents], Tout=values.dtype)
res.set_shape(values.get_shape().as_list())
return res