有没有办法在TensorFlow中提取方阵的对角线?也就是说,对于这样的矩阵:
[
[0, 1, 2],
[3, 4, 5],
[6, 7, 8]
]
我想获取元素:[0, 4, 8]
在 numpy 中,这通过np.diag非常简单:
在 TensorFlow 中,有一个diag 函数,但它仅与对角线上的参数中指定的元素形成一个新矩阵,这不是我想要的。
我可以想象如何通过跨步来做到这一点......但我没有看到 TensorFlow 中的张量跨步。
目前可以使用tf.diag_part提取对角线元素。这是他们的例子:
"""
'input' is [[1, 0, 0, 0],
[0, 2, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 4]]
"""
tf.diag_part(input) ==> [1, 2, 3, 4]
旧答案(当 diag_part 时)不可用(如果您想实现现在不可用的东西,仍然相关):
在查看了数学运算和张量变换之后,看起来并不存在这样的运算。即使您可以使用矩阵乘法提取此数据,它也不会是有效的(得到对角线是O(n)
)。
你有三种方法,从易到难。
使用 tf.diag_part()
with tf.Session() as sess:
x = tf.ones(shape=[3, 3])
x_diag = tf.diag_part(x)
print(sess.run(x_diag ))
这可能是一种解决方法,但有效。
>> sess = tensorflow.InteractiveSession()
>> x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]])
>> x.initializer.run()
>> z = tensorflow.pack([x[i,i] for i in range(3)])
>> z.eval()
array([1, 5, 9], dtype=int32)
根据上下文,掩码可能是“取消”矩阵的对角元素的好方法,特别是如果您计划无论如何都减少它:
mask = tf.diag(tf.ones([n]))
y = tf.mul(mask,y)
cost = -tf.reduce_sum(y)
使用gather
操作。
x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]])
x_flat = tf.reshape(x, [-1]) # flatten the matrix
x_diag = tf.gather(x, [0, 3, 6])