如果我想更改 RNN 单元(例如 GRU 单元)中的计算规则,我应该怎么做?
考虑到效率问题,我不想通过 for 或 while 循环来实现它。
我查看了 pytorch 的源代码,但似乎 rnn 单元的主要组件是用我无法找到和修改的 c 代码实现的。你可以通过一个例子来回答这个问题:在没有现有版本的情况下实现 GRU 单元。
谢谢~
如果我想更改 RNN 单元(例如 GRU 单元)中的计算规则,我应该怎么做?
考虑到效率问题,我不想通过 for 或 while 循环来实现它。
我查看了 pytorch 的源代码,但似乎 rnn 单元的主要组件是用我无法找到和修改的 c 代码实现的。你可以通过一个例子来回答这个问题:在没有现有版本的情况下实现 GRU 单元。
谢谢~
是的,您“通过 for 或 while 循环”实现它。自 Pytorch 1.0 以来,JIT https://pytorch.org/docs/stable/jit.html运行良好(由于最近对 JIT 的改进,使用 PyTorch 的最新 git 版本可能更好),并且取决于您的网络和实现可以与原生 PyTorch C++ 实现一样快(但仍比 CuDNN 慢)。
您可以在https://github.com/pytorch/benchmark/blob/master/rnns/fastrnns/custom_lstms.py查看示例实现