0

我已经编写了一个框架,它以抽象的方式连接不同的(相当复杂的)线性运算符。它覆盖运算符“+、*、@、-”并选择通过函数组合图的路径。至少可以说调试并不容易,但是控制流不依赖于数据本身,当然任何操作都是用 tensorflow 完成的。我希望使用 tf.function 来编译它并通过 XLA 获得(希望更快) tf.function。但是我收到以下错误:

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: Reshape_2:0

我没有在任何地方使用 tf.init_scope 并且有 8 个(!)关于这个错误的谷歌结果 - 虽然它们都没有为我提供如何调试它的任何线索。

    # initilize linear operators, these are python objects that override __matmul__ etc.
    P = ...
    A = ...
    # initilize vectors, these are compatible python objects to P and A 
    x = ...
    y = ...

    # This function recreates the python object from its raw tensorflow data. 
    # Since it might be dependend on the spaces and
    # they also need to be set up for deserializaton the method is returned by a function of x.
    # But since many vectors share the same spaces I was hoping to reuse it.
    deserialize = x.deserialize()
    
    # We want to compile the action on x to a function now
    def meth( data ):
        result = P @ ( A.T @ A @ deserialize( data ) )
        # we return only the raw data
        return result.serialize()

    meth = tf.function( meth,
                       #experimental_compile = True ,
                       input_signature = (x.serialize_signature,),
    ).get_concrete_function()
    
    # we want to use meth now for many vectors
    # executing this line throws the error 
    meth(x1)
    meth(x2)
    meth(x3)
   

不用说,没有 tf.function 也可以。有没有人偶然发现错误并可以帮助我更好地理解它?还是我尝试的孔设置不适合 tensorflow ?

编辑:

错误是由局部 lambda 隐式捕获线性运算符类中的常量张量引起的。老实说,错误消息暗示了类似的内容,但是很难理解代码中的哪一行导致它,并且最终找到错误并不容易。

4

0 回答 0