input_mb = tf.placeholder(tf.int32, [None, 166, 1], name="input_minibatch")
假设有上面的代码。我想获得上述小批量张量的行,使得每个检索到的行的第一个元素 == a。我如何在 Tensorflow 中做到这一点?另外,你如何在 Numpy 中做到这一点?
input_mb = tf.placeholder(tf.int32, [None, 166, 1], name="input_minibatch")
假设有上面的代码。我想获得上述小批量张量的行,使得每个检索到的行的第一个元素 == a。我如何在 Tensorflow 中做到这一点?另外,你如何在 Numpy 中做到这一点?
(给定一个值 a)
要在numpy中实现这一点,您只需编写:
selected_rows = myarray[myarray[:,0]== a]
在tensorflow中,使用 tf.where :
mytensor[tf.squeeze(tf.where(tf.equal(mytensor[:,0],a), None, None))
我会在 tensorflow 上这样做:
tf.gather(mytensor, tf.squeeze(tf.where(tf.equal(mytensor[:,0],a), None, None)), axis=0)