我正在玩trax库中的自注意力模型。
当我设置时n_heads=1
,一切正常。但是当我设置时n_heads=2
,我的代码会中断。
我只使用输入激活和一个 SelfAttention 层。
这是一个最小的代码:
import trax
import numpy as np
attention = trax.layers.SelfAttention(n_heads=2)
activations = np.random.randint(0, 10, (1, 100, 1)).astype(np.float32)
input = (activations, )
init = attention.init(input)
output = attention(input)
但我有一个错误:
File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/layers/research/efficient_attention.py, line 1637, in forward_unbatched_h
return forward_unbatched(*i_h, weights=w_h, state=s_h)
File [...]/layers/research/efficient_attention.py, line 1175, in forward_unbatched
q_info = kv_info = np.arange(q.shape[-2], dtype=np.int32)
IndexError: tuple index out of range
我做错了什么?