1

基本上,我想知道下面的代码是否安全。我想用来tf.py_function调用一些 scipy 代码,我想在其中评估调用会话的操作(包括设置变量)。原因是,我想使用一些 scipy 代码来做一些在 tensorflow 中难以编码的事情,但我不想中断计算图。我至少有两个例子:一个是在图中使用 LSODE 隐式 ODE 求解器,另一个是在图中使用 scipy 的强大最小化器(我正在考虑在需要重复求解 GP 的算法中调用 py_function 包装器内的 GPFlow 优化器优化问题)。

以下代码运行并返回我所期望的(值 1)。但我不知道它是否安全。我猜如果我还要使用x图中其他地方的值,它会给出不确定的行为。

from scipy.optimize import minimize
import tensorflow as tf

def build_func(session, y_and_grad, var):
    pl = tf.placeholder(tf.float32)
    assign = tf.assign(var, pl)
    y, grad = y_and_grad
    def func(x):
        """
        This could be something that uses ops and vars in the same graph,
        but requires iterative access to session calls.
        E.g. using scipy.minimize with tensorflow to compute the gradients.
        """
        x = x.numpy()
        def fun_and_jac(x):
            session.run(assign,{pl:x})
            y, jac = session.run(y_and_grad)
            return y, jac
        res = minimize(fun_and_jac, x0=x, jac=True)
        return res.x
    return func

with tf.Session(graph=tf.Graph()) as session:
    x = tf.Variable([0.], dtype=tf.float32)
    session.run(x.initializer)
    y = (x - 1.)**2
    grad = tf.gradients(y, x)[0]
    func = build_func(session, [y, grad], x)
    opt_x = tf.py_function(func, [x], [tf.float32])
    print(session.run(opt_x))
4

0 回答 0