1

在 python 中,假设

a = np.array(range(0,12)).reshape(2,2,3)
b = np.array(range(0,6)).reshape(3,2)
c = np.matmul(a,b) // a @ b

我们有

a: array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]]])

b: array([[0, 1],
       [2, 3],
       [4, 5]])

c: array([[[10, 13],
        [28, 40]],

       [[46, 67],
        [64, 94]]])

有人可以帮助我在没有 for 循环的情况下在 java nd4j中实现等效操作吗?我试过broadcast.mul了,但结果broadcast.mul是元素乘法。我没有找到任何针对 mmul 的广播操作。

4

1 回答 1

2

我自己想通了。答案如下所示,以防有人需要。有了Nd4j.tensorMmul,就可以轻松实现矩阵广播。例如

val a = Nd4j.create(0d to 11d by 1d toArray, Array[Int](2, 2, 3))
val b = Nd4j.create(0d to 5d by 1d toArray, Array[Int](3, 2))
Nd4j.tensorMmul(a, b, Array(Array(2), Array(0))) // matrix broadcast

这是scala的代码。对于 java,您只需要更改代码即可创建数组。

于 2018-12-05T15:24:54.123 回答