0

我想知道如何tf.argmax在 3D 数组中使用。

我的输入数据是这样的:

[[[0, -1, 5, 2, 1], [2, 2, 3, 2, 5], [6, 1, 2, 4, -1]], 
[[-1, -2, 3, 2, 1], [0, 3, 2, 7, -1], [-1, 5, 2, 1, 3]]]

我想通过这个输入数据得到 argmax 的输出,如下所示:

[[2, 4, 0], [2, 3, 1]]

我想以softmax_cross_entropy_with_logits这种格式使用函数。

我应该如何使用tf.nn.softmax_cross_entropy_with_logits函数和tf.equal(tf.argmax)tf.reduce_mean(tf.cast)

4

1 回答 1

0

你可以tf.argmax一起使用axis=3

a = tf.constant([[[0, -1, 5, 2, 1], [2, 2, 3, 2, 5], [6, 1, 2, 4, -1]], 
        [[-1, -2, 3, 2, 1], [0, 3, 2, 7, -1], [-1, 5, 2, 1, 3]]])
b = tf.argmax(a, axis=2)
于 2017-09-01T06:29:13.763 回答