TLDR:这取决于您的功能以及您是处于生产还是开发阶段。tf.function
如果您希望能够轻松调试您的函数,或者如果它受到 AutoGraph 或 tf.v1 代码兼容性的限制,请不要使用。我强烈建议观看 Inside TensorFlow 关于AutoGraph和Functions 的讨论,而不是 Sessions。
下面我将分解原因,这些都是从谷歌在线提供的信息中获取的。
通常,tf.function
装饰器会导致函数被编译为执行 TensorFlow 图的可调用对象。这需要:
- 如果需要,通过 AutoGraph 转换代码(包括从注释函数调用的任何函数)
- 跟踪并执行生成的图形代码
有关于这背后的设计理念的详细信息。
装饰函数的好处tf.function
一般福利
对于带有 Python 代码的函数/通过tf.function
装饰使用 AutoGraph
如果你想使用 AutoGraph,tf.function
强烈推荐使用而不是直接调用 AutoGraph。原因包括:自动控制依赖,某些 API 需要它,更多缓存和异常助手(来源)。
装饰函数的缺点tf.function
一般缺点
- 如果该函数仅包含少量昂贵的操作,则不会有太多的加速(来源)
对于带有 Python 代码的函数/通过tf.function
装饰使用 AutoGraph
- 没有异常捕获(应该在急切模式下完成;在装饰函数之外)(来源)
- 调试要困难得多
- 由于隐藏的副作用和 TF 控制流造成的限制
提供有关 AutoGraph 限制的详细信息。
对于具有 tf.v1 代码的函数
- 不允许在 中多次创建变量
tf.function
,但这可能会随着 tf.v1 代码的逐步淘汰而发生变化(来源)
对于带有 tf.v2 代码的函数
限制示例
多次创建变量
不允许多次创建变量,例如v
以下示例:
@tf.function
def f(x):
v = tf.Variable(1)
return tf.add(x, v)
f(tf.constant(2))
# => ValueError: tf.function-decorated function tried to create variables on non-first call.
在以下代码中,通过确保self.v
仅创建一次来缓解这种情况:
class C(object):
def __init__(self):
self.v = None
@tf.function
def f(self, x):
if self.v is None:
self.v = tf.Variable(1)
return tf.add(x, self.v)
c = C()
print(c.f(tf.constant(2)))
# => tf.Tensor(3, shape=(), dtype=int32)
AutoGraph 未捕捉到的隐藏副作用
self.a
无法隐藏此示例中的更改,这会导致错误,因为尚未完成跨功能分析(来源):
class C(object):
def change_state(self):
self.a += 1
@tf.function
def f(self):
self.a = tf.constant(0)
if tf.constant(True):
self.change_state() # Mutation of self.a is hidden
tf.print(self.a)
x = C()
x.f()
# => InaccessibleTensorError: The tensor 'Tensor("add:0", shape=(), dtype=int32)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=cond_true_5, id=5477800528); accessed from: FuncGraph(name=f, id=5476093776).
显而易见的变化是没有问题的:
class C(object):
@tf.function
def f(self):
self.a = tf.constant(0)
if tf.constant(True):
self.a += 1 # Mutation of self.a is in plain sight
tf.print(self.a)
x = C()
x.f()
# => 1
TF 控制流的限制示例
这个 if 语句会导致错误,因为需要为 TF 控制流定义 else 的值:
@tf.function
def f(a, b):
if tf.greater(a, b):
return tf.constant(1)
# If a <= b would return None
x = f(tf.constant(3), tf.constant(2))
# => ValueError: A value must also be returned from the else branch. If a value is returned from one branch of a conditional a value must be returned from all branches.