0

假设我有一个 64x64x64 的 3D 图像。我还有一个长度为 64 的向量 x。

我想像这样采用'argmax(x)'层:

2d_image = 3d_image[:,argmax(x),:]

更精确(对于张量流):

def extract_slice(x,3d_image):
     slice_index = tf.math.argmax(x,axis=1,output_type=tf.dtype.int32) #it has to be int for indexing
     return 3d_image[:,slice_index,:]

错误是:

Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got <tf.Tensor 'ArgMax_50:0' shape=(None,) dtype=int32>

参数的 np.shape 是:

3d_image 形状是(None, 64, 64, 64, 1)

x 形状是(None, 64)

slice_index 形状是(None,)

-> 3d_image 形状的 ,1 维是因为它是数组中的一个样本。我认为这无关紧要

我知道 None 形状是批量大小,这是未知的,但其他看起来很棒..那么问题是什么?

据我了解,看起来索引不是 int32,但实际上我确实将它转换为 tf.int 那么可能是什么问题?也许 int32 与 tf.int32 不同?或者我使用的索引方法在张量流中无效?也许它应该是一个类似的函数:tf.index(image,[:,slice_index,:])..?

谢谢!

4

1 回答 1

0

Argmax 返回一维张量。将其转换为标量:

 slice_index = tf.reshape(slice_index, ())
于 2020-12-23T19:51:44.880 回答