10

有没有办法在TensorFlow中提取方阵的对角线?也就是说,对于这样的矩阵:

[
 [0, 1, 2],
 [3, 4, 5],
 [6, 7, 8]
]

我想获取元素:[0, 4, 8]

在 numpy 中,这通过np.diag非常简单:

在 TensorFlow 中,有一个diag 函数,但它仅与对角线上的参数中指定的元素形成一个新矩阵,这不是我想要的。

我可以想象如何通过跨步来做到这一点......但我没有看到 TensorFlow 中的张量跨步。

4

6 回答 6

10

使用 tensorflow 0.8 可以提取对角线元素tf.diag_part()(请参阅文档

更新

对于 tensorflow >= r1.12 它tf.linalg.tensor_diag_part(参见文档

于 2016-04-20T12:57:54.423 回答
4

目前可以使用tf.diag_part提取对角线元素。这是他们的例子:

"""
'input' is [[1, 0, 0, 0],
            [0, 2, 0, 0],
            [0, 0, 3, 0],
            [0, 0, 0, 4]]
"""

tf.diag_part(input) ==> [1, 2, 3, 4]

旧答案(当 diag_part 时)不可用(如果您想实现现在不可用的东西,仍然相关):

在查看了数学运算张量变换之后,看起来并不存在这样的运算。即使您可以使用矩阵乘法提取此数据,它也不会是有效的(得到对角线是O(n))。

你有三种方法,从易到难。

  1. 评估张量,用 numpy 提取对角线,用 TF 构建变量
  2. 以 Anurag 建议的方式使用tf.pack(也使用 3 提取值tf.shape
  3. 用 C++编写自己的操作,重建 TF 并在本地使用它。
于 2015-11-14T02:36:31.043 回答
3

使用 tf.diag_part()

with tf.Session() as sess:
    x = tf.ones(shape=[3, 3])
    x_diag = tf.diag_part(x)
    print(sess.run(x_diag ))
于 2016-12-29T03:11:32.043 回答
2

这可能是一种解决方法,但有效。

>> sess = tensorflow.InteractiveSession()
>> x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]])
>> x.initializer.run()
>> z = tensorflow.pack([x[i,i] for i in range(3)])
>> z.eval()
array([1, 5, 9], dtype=int32)
于 2015-11-14T01:03:21.160 回答
0

根据上下文,掩码可能是“取消”矩阵的对角元素的好方法,特别是如果您计划无论如何都减少它:

mask = tf.diag(tf.ones([n]))
y = tf.mul(mask,y)
cost = -tf.reduce_sum(y)
于 2016-01-16T20:17:48.563 回答
0

使用gather操作。

x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]])
x_flat = tf.reshape(x, [-1])  # flatten the matrix
x_diag = tf.gather(x, [0, 3, 6])
于 2016-01-05T05:44:41.157 回答