8

我正在尝试TF2/Keras按照官方 Keras 演练编写自己的训练循环。vanilla 版本就像一个魅力,但是当我尝试将@tf.function装饰器添加到我的训练步骤时,一些内存泄漏占用了我所有的内存并且我失去了对我机器的控制,有人知道发生了什么吗?

代码的重要部分如下所示:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = siamese_network(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, siamese_network.trainable_weights)
    optimizer.apply_gradients(zip(grads, siamese_network.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

@tf.function
def test_step(x, y):
    val_logits = siamese_network(x, training=False)
    val_acc_metric.update_state(y, val_logits)
    val_prec_metric.update_state(y_batch_val, val_logits)
    val_rec_metric.update_state(y_batch_val, val_logits)


for epoch in range(epochs):
        step_time = 0
        epoch_time = time.time()
        print("Start of {} epoch".format(epoch))
        for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
            if step > steps_epoch:
                break
           
            loss_value = train_step(x_batch_train, y_batch_train)
        train_acc = train_acc_metric.result()
        train_acc_metric.reset_states()
        
        for val_step,(x_batch_val, y_batch_val) in enumerate(test_ds):
            if val_step>validation_steps:
                break
            test_step(x_batch_val, y_batch_val)
         
        val_acc = val_acc_metric.result()
        val_prec = val_prec_metric.result()
        val_rec = val_rec_metric.result()

        val_acc_metric.reset_states()
        val_prec_metric.reset_states()
        val_rec_metric.reset_states()

如果我对这些@tf.function行发表评论,则不会发生内存泄漏,但步骤时间慢了 3 倍。我的猜测是,不知何故,图表是在每个时期或类似的情况下再次创建的 bean,但我不知道如何解决它。

这是我正在关注的教程:https ://keras.io/guides/writing_a_training_loop_from_scratch/

4

1 回答 1

0

tl;博士;

TensorFlow 可能会为传递给修饰函数的每个唯一参数值集生成一个新图。确保您将一致形状的Tensor对象传递给test_steptrain_step不是 python 对象。

细节

这是在黑暗中刺伤。虽然我从未尝试过,但我确实在文档@tf.function中发现了以下警告:

tf.function 还将任何纯 Python 值视为不透明对象,并为它遇到的每组 Python 参数构建一个单独的图。

注意:将 python 标量或列表作为参数传递给 tf.function 将始终构建一个新图。为避免这种情况,请尽可能将数字参数作为张量传递

最后:

函数通过从输入的 args 和 kwargs 计算缓存键来确定是否重用跟踪的 ConcreteFunction。缓存键是根据函数调用的输入 args 和 kwargs 识别 ConcreteFunction 的键,根据以下规则(可能会更改):

  • 为 tf.Tensor 生成的关键是它的形状和 dtype。
  • 为 tf.Variable 生成的键是唯一的变量 id。
  • 为 Python 原语(如 int、float、str)生成的键是它的值。
  • 为嵌套的字典、列表、元组、命名元组和属性生成的键是叶子键的扁平元组(参见nest.flatten)。(由于这种扁平化,调用具有与跟踪期间使用的嵌套结构不同的嵌套结构的具体函数将导致 TypeError)。
  • 对于所有其他 Python 类型,键对于对象是唯一的。这样,一个函数或方法就可以独立地跟踪每个调用它的实例。

我从这一切中得到的是,如果您没有将大小一致的 Tensor 对象传递给您的@tf.function-ified 函数(也许您使用 Python 集合或原语),那么您很可能正在创建一个新的图形版本函数与您传入的每个不同的参数值。我猜这可能会产生您所看到的内存爆炸行为。我不知道你的test_dstrain_ds对象是如何创建的,但你可能想确保它们的创建方式enumerate(blah_ds)像教程中一样返回张量,或者至少在传递给你的test_steptrain_step函数之前将值转换为张量。

于 2021-05-08T20:46:32.143 回答