0

我有一个模型,我想计算它的梯度 wrt 输入。计算需要内存,因此我想将它分成批次。

由于我关心计算时间,我想将所有内容都包装在tf.function.

这是我会做的一个例子:

import tensorflow as tf
import tensorflow_probability as tfp

def model(sample):
    # Just a trivial example. This function in reality takes more arguments and creates
    # a large computational graph
    return tf.reduce_logsumexp(tfp.distributions.Normal(0., 1.).log_prob(sample), axis=1)

input_ = tf.random.uniform(maxval=1., shape=(100,10000000))

compiled_model = tf.function(model)

def get_batches(vars_, batch_size=10):
    current_beginning = 0
    all_elems = vars_[0].shape[0]
    while current_beginning < all_elems:
        yield tf.Variable(vars_[current_beginning:current_beginning+batch_size])
        current_beginning += batch_size

res = []        

for batch in get_batches(input_, batch_size=1):
    with tf.GradientTape() as tape_logprob:
        tape_logprob.watch(batch)
        log_prob = compiled_model(batch)
    
    res.append(tape_logprob.gradient(log_prob, batch))
    

如果您运行此代码,您会发现它会在 XLA 编译期间导致回溯,并严重影响性能:

WARNING:tensorflow:5 次调用 <function model at 0x7f945279b9d8> 中的 5 次触发了 tf.function 回溯。跟踪是昂贵的,过多的跟踪可能是由于(1)在循环中重复创建 @tf.function,(2)传递不同形状的张量,(3)传递 Python 对象而不是张量。对于 (1),请在循环之外定义您的 @tf.function。对于 (2),@tf.function 具有 Experimental_relax_shapes=True 选项,可以放宽可以避免不必要的回溯的参数形状。对于(3),更多细节请参考https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_argshttps://www.tensorflow.org/api_docs/python/tf/function

我不明白为什么要在这里进行回溯。遵循警告中提到的几点:1)。我没有tf.function在循环中定义(尽管我在循环中运行它)。2)。输入张量的形状总是相同的,因此我相信编译应该只发生一次。3)。我不使用普通的 Python 对象。

我在这里缺少什么细微差别?如何使这个例子工作?

在进行实验时,我注意到我可以通过将单个批次包装log_prob = compiled_model(batch)成一个琐碎tf.map_fn的内容来消除警告消息,但与非批处理版本的计算相比,我仍然观察到性能下降很大。

4

0 回答 0