1

因此,这里说间接修改不应该起作用,这意味着更改将是不可见的(无论如何,不​​可见的更改是什么意思?)

但是这段代码正确地计算了梯度:

import tensorflow as tf


class C:
    def __init__(self):
        self.x = tf.Variable(2.0)

    @tf.function
    def change(self):
        self.x.assign_add(2.0)

    @tf.function
    def func(self):
        self.change()
        return self.x * self.x


c = C()
with tf.GradientTape() as tape:
    y = c.func()
print(tape.gradient(y, c.x)) # --> tf.Tensor(8.0, shape=(), dtype=float32)

我在这里错过了什么吗?

谢谢

4

1 回答 1

1

文档缺少细节,应予以澄清 - “不可见”意味着 AutoGraph 的分析器未检测到更改。由于 AutoGraph 一次分析一个函数,因此分析仪看不到在另一个函数中所做的修改。

但是,此警告不适用于具有副作用的 Ops,例如对 TF 变量的修改——这些仍然会在图中正确连接。所以你的代码应该可以正常工作。

该限制仅适用于对纯 Python 对象(列表、字典等)所做的一些更改,并且仅在使用控制流时才会出现问题。

例如,这是对您的代码的修改,但该修改不起作用:

class C:
    def __init__(self):
        self.x = None

    def reset(self):
        self.x = tf.constant(10)

    def change(self):
        self.x += 1

    @tf.function
    def func(self):
      self.reset()
      for i in tf.range(3):
        self.change()
      return self.x * self.x


c = C()
print(c.func())

错误消息相当模糊,但如果您尝试访问在 a 的主体内创建的操作的结果tf.while_loop而不使用,则会引发相同的错误loop_vars

    <ipython-input-18-23f1641cfa01>:20 func  *
        return self.x * self.x

    ... more internal frames ...

    InaccessibleTensorError: The tensor 'Tensor("add:0", shape=(),
dtype=int32)' cannot be accessed here: it is defined in another function or
code block. Use return values, explicit Python locals or TensorFlow
collections to access it. Defined in: FuncGraph(name=while_body_685,
id=5029696157776); accessed from: FuncGraph(name=func, id=5029690557264).
于 2019-12-07T12:50:59.700 回答