我正在尝试在 Python 中构建一个神经网络来解决 PDE,因此,我不得不编写自定义训练步骤。我的训练功能如下所示:
...
tf.enable_eager_execution()
class PDENet:
...
def train_step():
input = self.input
with tf.GradientTape() as tape, tf.Session() as sess:
tape.watch(input)
output = self.model(input)
self.loss = self.pde_loss(output) # (network does not use training data)
grad = tape.gradient(self.loss, self.model.trainable_weights)
self.optimizer.apply_gradients([(grad, self.model)])
...
由于我的硬件,我别无选择,只能使用 tensorflow==1.12.0 和 keras==2.2.4。
当我运行此代码时,我得到“RuntimeError: Attempting to capture an EagerTensor without building a function”。我看过其他关于此的帖子,但所有答案都说要更新 tensorflow/keras,我不能,使用我已经完成的“tf.enable_eager_execution()”和“tf.disable_v2_behavior()” ,这在旧版本的 tensorflow 上不存在。我还能做些什么来解决这个问题吗?该错误让我认为 tensorflow 想要我添加@tf.function
,但 tensorflow 1 中似乎也不存在该功能。