我有一个模型,我想计算它的梯度 wrt 输入。计算需要内存,因此我想将它分成批次。
由于我关心计算时间,我想将所有内容都包装在tf.function
.
这是我会做的一个例子:
import tensorflow as tf
import tensorflow_probability as tfp
def model(sample):
# Just a trivial example. This function in reality takes more arguments and creates
# a large computational graph
return tf.reduce_logsumexp(tfp.distributions.Normal(0., 1.).log_prob(sample), axis=1)
input_ = tf.random.uniform(maxval=1., shape=(100,10000000))
compiled_model = tf.function(model)
def get_batches(vars_, batch_size=10):
current_beginning = 0
all_elems = vars_[0].shape[0]
while current_beginning < all_elems:
yield tf.Variable(vars_[current_beginning:current_beginning+batch_size])
current_beginning += batch_size
res = []
for batch in get_batches(input_, batch_size=1):
with tf.GradientTape() as tape_logprob:
tape_logprob.watch(batch)
log_prob = compiled_model(batch)
res.append(tape_logprob.gradient(log_prob, batch))
如果您运行此代码,您会发现它会在 XLA 编译期间导致回溯,并严重影响性能:
WARNING:tensorflow:5 次调用 <function model at 0x7f945279b9d8> 中的 5 次触发了 tf.function 回溯。跟踪是昂贵的,过多的跟踪可能是由于(1)在循环中重复创建 @tf.function,(2)传递不同形状的张量,(3)传递 Python 对象而不是张量。对于 (1),请在循环之外定义您的 @tf.function。对于 (2),@tf.function 具有 Experimental_relax_shapes=True 选项,可以放宽可以避免不必要的回溯的参数形状。对于(3),更多细节请参考https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args和https://www.tensorflow.org/api_docs/python/tf/function。
我不明白为什么要在这里进行回溯。遵循警告中提到的几点:1)。我没有tf.function
在循环中定义(尽管我在循环中运行它)。2)。输入张量的形状总是相同的,因此我相信编译应该只发生一次。3)。我不使用普通的 Python 对象。
我在这里缺少什么细微差别?如何使这个例子工作?
在进行实验时,我注意到我可以通过将单个批次包装log_prob = compiled_model(batch)
成一个琐碎tf.map_fn
的内容来消除警告消息,但与非批处理版本的计算相比,我仍然观察到性能下降很大。