-1

我试图找到对大张量(> 200 行和列)执行路径相关更新的最有效的 tensorflow 方法。

解决方案需要是可区分的(并且可能与 xla 兼容)

我目前正在使用 tf.unstack,检查 for 循环中的每个张量并使用 tf.where 过滤掉我想要的条件。这非常慢并且导致许多张量操作


Bt = tf.ones([256])
Bt_n = tf.random_normal([200,256]) # would actually be calculated elsewhere
Mr = tf.random_normal([200,256])
Mp = tf.random_normal([200,256])

total = [Bt]

for mr, mp, n_Bt in zip(tf.unstack(Mr), 
                      tf.unstack(Mp),                                                      
                      tf.unstack(Bt_n)):
    Bt = tf.where(tf.logical_or(Bt <= mr, Bt >= mp), n_Bt, Bt)
    total.append(Bt)

final = tf.concat(total, axis=0)

只是寻找最有效(需要最少的操作)的方法来实现这一点。

谢谢。

4

1 回答 1

0

找到了答案 - 我需要使用 tf.scan

IE。

tf.scan(lambda a, x: tf.where(tf.logical_or(a <= x[0], a >= x[1]), x[2], a) , [Mr,Mp,Bt_n], initializer = Bt)
于 2019-06-11T06:15:52.797 回答