我已经编写了一个框架,它以抽象的方式连接不同的(相当复杂的)线性运算符。它覆盖运算符“+、*、@、-”并选择通过函数组合图的路径。至少可以说调试并不容易,但是控制流不依赖于数据本身,当然任何操作都是用 tensorflow 完成的。我希望使用 tf.function 来编译它并通过 XLA 获得(希望更快) tf.function。但是我收到以下错误:
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
@tf.function
def has_init_scope():
my_constant = tf.constant(1.)
with tf.init_scope():
added = my_constant * 2
The graph tensor has name: Reshape_2:0
我没有在任何地方使用 tf.init_scope 并且有 8 个(!)关于这个错误的谷歌结果 - 虽然它们都没有为我提供如何调试它的任何线索。
# initilize linear operators, these are python objects that override __matmul__ etc.
P = ...
A = ...
# initilize vectors, these are compatible python objects to P and A
x = ...
y = ...
# This function recreates the python object from its raw tensorflow data.
# Since it might be dependend on the spaces and
# they also need to be set up for deserializaton the method is returned by a function of x.
# But since many vectors share the same spaces I was hoping to reuse it.
deserialize = x.deserialize()
# We want to compile the action on x to a function now
def meth( data ):
result = P @ ( A.T @ A @ deserialize( data ) )
# we return only the raw data
return result.serialize()
meth = tf.function( meth,
#experimental_compile = True ,
input_signature = (x.serialize_signature,),
).get_concrete_function()
# we want to use meth now for many vectors
# executing this line throws the error
meth(x1)
meth(x2)
meth(x3)
不用说,没有 tf.function 也可以。有没有人偶然发现错误并可以帮助我更好地理解它?还是我尝试的孔设置不适合 tensorflow ?
编辑:
错误是由局部 lambda 隐式捕获线性运算符类中的常量张量引起的。老实说,错误消息暗示了类似的内容,但是很难理解代码中的哪一行导致它,并且最终找到错误并不容易。