2

假设我有一个W形状矩阵,(n_words, model_dim)其中n_words是句子中model_dim的单词数,是表示单词向量的空间的维度。计算这些向量的移动平均值的最快方法是什么?

例如,窗口大小为 2(窗口长度 = 5),我可以有这样的东西(这会引发错误TypeError: JAX 'Tracer' objects do not support item assignment):

from jax import random
import jax.numpy as jnp

# Fake word vectors (17 words vectors of dimension 32)
W = random.normal(random.PRNGKey(0), shape=(17, 32)) 

ws = 2          # window size
N = W.shape[0]  # number of words

new_W = jnp.zeros(W.shape)

for i in range(N):
    window = W[max(0, i-ws):min(N, i+ws+1)]
    n = window.shape[0]
    for j in range(n):
        new_W[i] += W[j] / n

我想有一个更快的解决方案,jnp.convolve但我不熟悉它。

4

1 回答 1

1

这看起来像您正在尝试进行卷积,因此jnp.convolve或类似的方法可能是一种性能更高的方法。

也就是说,您的示例有点奇怪,因为n它从不大于 4,因此除了W. 此外,您会覆盖内部循环的每次迭代中的前一个值,因此 的每一行new_W只包含前四行之一的缩放副本W

将您的代码更改为我认为您的意思,并使用index_update使其与 JAX 的不可变数组兼容,这样可以:

from jax import random
import jax.numpy as jnp

# Fake word vectors (17 words vectors of dimension 32)
W = random.normal(random.PRNGKey(0), shape=(17, 32)) 

ws = 2          # window size
N = W.shape[0]  # number of words

new_W = jnp.zeros(W.shape)

for i in range(N):
    window = W[max(0, i-ws):min(N, i+ws)]
    n = window.shape[0]
    for j in range(n):
      new_W = new_W.at[i].add(window[j] / n)

就更有效的卷积而言,这是等效的:

from jax.scipy.signal import convolve
kernel = jnp.ones((4, 1))
new_W_2 = convolve(W, kernel, mode='same') / convolve(jnp.ones_like(W), kernel, mode='same')

jnp.allclose(new_W, new_W_2)
# True
于 2021-06-09T16:50:11.883 回答