I'm trying to implement a beam search decoding strategy for a text generation model. This is the function I use to decode the output probabilities.
import torch

def beam_search_decoder(data, k):
    # data: (max_len, vocab_size) tensor of per-step token probabilities
    sequences = [[list(), 0.0]]  # each entry: [token indices so far, cumulative score]
    # walk over each step in the sequence
    for row in data:
        all_candidates = list()
        # expand every live sequence with every token in the vocabulary
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                # accumulate negative log-probability (lower is better)
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # sort candidates by score and keep the k best
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        sequences = ordered[:k]
    return sequences
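For context, here is a tiny usage sketch of how I call it (shapes shrunk from the real (150, 9907) so it runs quickly; the rows are softmax-normalized stand-ins for the model's per-step probabilities):

import torch

probs = torch.softmax(torch.randn(5, 20), dim=-1)  # stand-in for one example's (max_len, vocab_size) output
for seq, score in beam_search_decoder(probs, k=3):
    print(seq, float(score))  # k best token-index sequences, lowest total NLL first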
As you can see, this function is written with batch_size 1 in mind. Adding another loop over the batch dimension would make the algorithm roughly O(n^4), and it is already slow as it is. Is there any way to speed this function up? My model output is typically of size (32, 150, 9907), following the format (batch_size, max_len, vocab_size).
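To be concrete about the extra loop I mean, a minimal sketch of the naive batched wrapper (the helper name is just for illustration; it adds one more Python loop, which is exactly why it stays slow):

def beam_search_decoder_batched(batch_data, k):
    # batch_data: (batch_size, max_len, vocab_size); decode each example independently
    return [beam_search_decoder(example, k) for example in batch_data]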