1

这是我的示例 LSTM 类。表示英文文本字符串中字符的输入 ID 被编码为 one-hot 向量并馈送到单个 LSTM 层。

class MyLSTM(Chain):
    def __init__(self, vocab_size, hidden_size):
        super(MyLSTM, self).__init__(
            mid=L.LSTM(vocab_size, hidden_size),
            out=L.Linear(hidden_size, vocab_size),
        )
        self.W = np.identity(vocab_size).astype(np.float32)

    def reset_state(self):
        self.mid.reset_state()

    def __call__(self, x):
        x_1hot = F.embed_id(x, self.W)
        h = self.mid(x_1hot)
        y = self.out(h)
        return y

完整的代码在这里,只需将它指向一个示例 txt 文件,它应该运行: https ://gist.github.com/chris838/cea1987c38e0f29a2a514ad229454c0e

在第一个时期,我每秒运行 20 次迭代。到第 3 个 epoch 时,单次迭代需要一秒钟以上!

4

0 回答 0