假设我有一个W
形状矩阵,(n_words, model_dim)
其中n_words
是句子中model_dim
的单词数,是表示单词向量的空间的维度。计算这些向量的移动平均值的最快方法是什么?
例如,窗口大小为 2(窗口长度 = 5),我可以有这样的东西(这会引发错误TypeError: JAX 'Tracer' objects do not support item assignment
):
from jax import random
import jax.numpy as jnp
# Fake word vectors (17 words vectors of dimension 32)
W = random.normal(random.PRNGKey(0), shape=(17, 32))
ws = 2 # window size
N = W.shape[0] # number of words
new_W = jnp.zeros(W.shape)
for i in range(N):
window = W[max(0, i-ws):min(N, i+ws+1)]
n = window.shape[0]
for j in range(n):
new_W[i] += W[j] / n
我想有一个更快的解决方案,jnp.convolve
但我不熟悉它。