我试图找到对大张量(> 200 行和列)执行路径相关更新的最有效的 tensorflow 方法。
解决方案需要是可区分的(并且可能与 xla 兼容)
我目前正在使用 tf.unstack,检查 for 循环中的每个张量并使用 tf.where 过滤掉我想要的条件。这非常慢并且导致许多张量操作
Bt = tf.ones([256])
Bt_n = tf.random_normal([200,256]) # would actually be calculated elsewhere
Mr = tf.random_normal([200,256])
Mp = tf.random_normal([200,256])
total = [Bt]
for mr, mp, n_Bt in zip(tf.unstack(Mr),
tf.unstack(Mp),
tf.unstack(Bt_n)):
Bt = tf.where(tf.logical_or(Bt <= mr, Bt >= mp), n_Bt, Bt)
total.append(Bt)
final = tf.concat(total, axis=0)
只是寻找最有效(需要最少的操作)的方法来实现这一点。
谢谢。