1

我正在尝试在 Python 中构建一个神经网络来解决 PDE,因此,我不得不编写自定义训练步骤。我的训练功能如下所示:

...
tf.enable_eager_execution()

class PDENet:

     ...
  
     def train_step():
          input = self.input
          
          with tf.GradientTape() as tape, tf.Session() as sess:
               tape.watch(input)
               
               output = self.model(input)
               self.loss = self.pde_loss(output) # (network does not use training data)
               
          grad = tape.gradient(self.loss, self.model.trainable_weights)
          self.optimizer.apply_gradients([(grad, self.model)])

     ...

由于我的硬件,我别无选择,只能使用 tensorflow==1.12.0 和 keras==2.2.4。
当我运行此代码时,我得到“RuntimeError: Attempting to capture an EagerTensor without building a function”。我看过其他关于此的帖子,但所有答案都说要更新 tensorflow/keras,我不能,使用我已经完成的“tf.enable_eager_execution()”和“tf.disable_v2_behavior()” ,这在旧版本的 tensorflow 上不存在。我还能做些什么来解决这个问题吗?该错误让我认为 tensorflow 想要我添加@tf.function,但 tensorflow 1 中似乎也不存在该功能。

4

0 回答 0