4

例如,我需要创建一个像这样 (5,3,2) 的 3D 张量

array([[[0, 0],
        [0, 1],
        [0, 0]],

       [[1, 0],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [1, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [1, 0]],

       [[0, 0],
        [0, 1],
        [0, 0]]])

每个切片中都应该随机放置一个“一”(如果您认为张量是一条面包)。这可以使用循环来完成,但我想对这部分进行矢量化。

4

3 回答 3

4

尝试生成一个随机数组,然后找到max

a = np.random.rand(5,3,2)
out = (a == a.max(axis=(1,2))[:,None,None]).astype(int)
于 2021-02-16T03:10:33.513 回答
2

最直接的方法可能是创建一个零数组,并将随机索引设置为 1。在 NumPy 中,它可能如下所示:

import numpy as np

K, M, N = 5, 3, 2
i = np.random.randint(0, M, K)
j = np.random.randint(0, N, K)
x = np.zeros((K, M, N))
x[np.arange(K), i, j] = 1

在 JAX 中,它可能看起来像这样:

import jax.numpy as jnp
from jax import random

K, M, N = 5, 3, 2
key1, key2 = random.split(random.PRNGKey(0))
i = random.randint(key1, (K,), 0, M)
j = random.randint(key2, (K,), 0, N)
x = jnp.zeros((K, M, N)).at[jnp.arange(K), i, j].set(1)

一个更简洁的选项,也保证1每个切片一个,是使用具有适当构造的范围的随机整数的广播相等性:

r = random.randint(random.PRNGKey(0), (K, 1, 1), 0, M * N)
x = (r == jnp.arange(M * N).reshape(M, N)).astype(int)
于 2021-02-16T05:07:35.780 回答
0

您可以创建一个零数组,其中每个子数组的第一个元素为 1,然后permute跨最后两个轴:

x = np.zeros((5,3,2)); x[:,0,0] = 1

rng = np.random.default_rng()
x = rng.permuted(rng.permuted(x, axis=-1), axis=-2)
于 2021-04-01T15:23:27.950 回答