16

我正在研究 cs231n,我很难理解这个索引是如何工作的。鉴于

x = [[0,4,1], [3,2,4]]
dW = np.zeros(5,6)
dout = [[[  1.19034710e-01  -4.65005990e-01   8.93743168e-01  -9.78047129e-01
            -8.88672957e-01  -4.66605091e-01]
         [ -1.38617461e-03  -2.64569728e-01  -3.83712733e-01  -2.61360826e-01
            8.07072009e-01  -5.47607277e-01]
         [ -3.97087458e-01  -4.25187949e-02   2.57931759e-01   7.49565950e-01
           1.37707667e+00   1.77392240e+00]]

       [[ -1.20692745e+00  -8.28111550e-01   6.53041092e-01  -2.31247762e+00
         -1.72370321e+00   2.44308033e+00]
        [ -1.45191870e+00  -3.49328154e-01   6.15445782e-01  -2.84190582e-01
           4.85997687e-02   4.81590106e-01]
        [ -1.14828583e+00  -9.69055406e-01  -1.00773809e+00   3.63553835e-01
          -1.28078363e+00  -2.54448436e+00]]]

他们做的操作是

np.add.at(dW, x, dout)

x 是一个二维数组。索引如何在这里工作?我浏览了np.ufunc.at文档,但他们有简单的一维数组和常量示例:

np.add.at(a, [0, 1, 2, 2], 1)
4

3 回答 3

15
In [226]: x = [[0,4,1], [3,2,4]]
     ...: dW = np.zeros((5,6),int)

In [227]: np.add.at(dW,x,1)
In [228]: dW
Out[228]: 
array([[0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0]])

这样x就没有任何重复的条目,因此add.at与使用+=索引相同。等效地,我们可以通过以下方式读取更改的值:

In [229]: dW[x[0], x[1]]
Out[229]: array([1, 1, 1])

无论哪种方式,索引的工作方式都相同,包括广播:

In [234]: dW[...]=0
In [235]: np.add.at(dW,[[[1],[2]],[2,4,4]],1)
In [236]: dW
Out[236]: 
array([[0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 2, 0],
       [0, 0, 1, 0, 2, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]])

可能的值

broadcastable相对于索引,值必须是:

In [112]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)))
...
In [114]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)).ravel())
...
ValueError: array is not broadcastable to correct shape
In [115]: np.add.at(dW,[[[1],[2]],[2,4,4]],[1,2,3])

In [117]: np.add.at(dW,[[[1],[2]],[2,4,4]],[[1],[2]])

In [118]: dW
Out[118]: 
array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  0,  3,  0,  9,  0],
       [ 0,  0,  4,  0, 11,  0],
       [ 0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0]])

在这种情况下,索引定义了 (2,3) 形状,因此 (2,3)、(3,)、(2,1) 和标量值起作用。(6,) 没有。

在这种情况下,add.at将 (2,3) 数组映射到 的 (2,2) 子数组dW

于 2017-08-03T04:13:58.293 回答
7

最近我也很难理解这行代码。希望我得到的可以帮助你,如果我错了,请纠正我。

这行代码中的三个数组如下:

x , whose shape is (N,T)
dW,  ---(V,D)
dout ---(N,T,D)

然后我们来到我们想弄清楚发生了什么的行代码

np.add.at(dW, x, dout)

如果你不想知道思维过程。上面的代码等价于:

for row in range(N):
   for col in range(T):
      dW[ x[row,col]  , :] += dout[row,col, :]

这是思考过程:

参考这个文档

https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ufunc.at.html

我们知道 x 是索引数组。所以关键是理解dW[x]。这是使用另一个数组(x)索引一个数组(dW)的概念。如果您不熟悉此概念,可以查看此链接

https://docs.scipy.org/doc/numpy-1.13.0/user/basics.indexing.html

一般来说,使用索引数组时返回的是一个与索引数组具有相同形状的数组,但索引数组的类型和值。

dW[x] 将给我们一个形状为 (N,T,D) 的数组,其中 (N,T) 部分来自 x,而 (D) 部分来自 dW (V,D)。注意,x 的每个元素都在 [0, v) 的范围内。

让我们以一些数字为例

x:    np.array([[0,0],[0,0]]) ---- (2,2) N=2, T=2
dW:   np.array([[0,0],[2,2]]) ---- (2,2) V=2, D=2
dout: np.arange(1,9).reshape(2,2,2)  ----(2,2,2) N=2, T=2, D=2

dW[x] should be [ [[0 0] #this comes from the dW's firt row
                  [0 0]]

                  [[0 0]
                   [0 0]] ]

dW[x] add dout 表示添加 elemnet 项(这里,这是一些技巧,稍后会解释)

np.add.at(dW, x, dout) gives 
 [ [16 20]
   [ 2  2] ]

为什么?程序是:

它将 [1,2] 添加到 dW 的第一行,即 [0,0]。

为什么是第一排?因为x[0,0]=0,表示dW的第一行,dW[0]=dW[0,:]=第一行。

然后将 [3,4] 添加到 dW[0,0] 的第一行。[3,4]=dout[0,1,:]。[0,0] 再次来自 dW,x[0,1] = 0,仍然是 dW[0] 的第一行。

然后将 [5,6] 添加到 dW 的第一行。

然后将 [7,8] 添加到 dW 的第一行。

所以结果是 [1+3+5+7, 2+4+6+8] = [16,20]。因为我们没有触及 dW 的第二行。dW 的第二行保持不变。

诀窍是我们只会统计一次原点行,可以认为没有缓冲区,每一步都在原处播放。

于 2018-01-25T09:02:42.813 回答
0

让我们考虑一个基于 cs231n 分配的示例。如果我们谈论多个方向,则使用具体设置要容易得多。

np.random.seed(1)
N, T, V, D = 2, 3, 7, 6
x = np.random.randint(V, size=(N, T))
dW_man = np.zeros((V, D))

dW_man[x].shape, x.shape
((2, 3, 6), (2, 3))

x
array([[5, 3, 4],
   [0, 1, 3]])

dout = np.arange(2*3*6).reshape(dW_man[x].shape)
dout
array([[[ 0,  1,  2,  3,  4,  5],
    [ 6,  7,  8,  9, 10, 11],
    [12, 13, 14, 15, 16, 17]],

   [[18, 19, 20, 21, 22, 23],
    [24, 25, 26, 27, 28, 29],
    [30, 31, 32, 33, 34, 35]]])

应该是什么行dW_man[x]?Well[0, 1, ...]应该添加到第 5 行,[ 6, 7, ..]- 到第 3 行。也[30, 31, ...]应该添加到第 3 行。所以让我们手动计算它。在此 GitHub 要点中查看更多示例和说明:链接

dW_man[5] = dout[0, 0]
dW_man[3] = dout[0, 1]
dW_man[4] = dout[0, 2]

dW_man[0] = dout[1, 0]
dW_man[1] = dout[1, 1]
dW_man[3] = dout[1, 2]

dW_man
array([[18., 19., 20., 21., 22., 23.],
   [24., 25., 26., 27., 28., 29.],
   [ 0.,  0.,  0.,  0.,  0.,  0.],
   [30., 31., 32., 33., 34., 35.],
   [12., 13., 14., 15., 16., 17.],
   [ 0.,  1.,  2.,  3.,  4.,  5.],
   [ 0.,  0.,  0.,  0.,  0.,  0.]])

现在让我们使用np.add.at.

np.random.seed(1)
N, T, V, D = 2, 3, 7, 6
x = np.random.randint(V, size=(N, T))
dW = np.zeros((V, D))
dout = np.arange(2*3*6).reshape(dW[x].shape)
np.add.at(dW, x, dout)

dW
array([[18., 19., 20., 21., 22., 23.],
       [24., 25., 26., 27., 28., 29.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [36., 38., 40., 42., 44., 46.],
       [12., 13., 14., 15., 16., 17.],
       [ 0.,  1.,  2.,  3.,  4.,  5.],
       [ 0.,  0.,  0.,  0.,  0.,  0.]])
于 2021-11-04T10:40:52.700 回答