基本上,我想知道下面的代码是否安全。我想用来tf.py_function
调用一些 scipy 代码,我想在其中评估调用会话的操作(包括设置变量)。原因是,我想使用一些 scipy 代码来做一些在 tensorflow 中难以编码的事情,但我不想中断计算图。我至少有两个例子:一个是在图中使用 LSODE 隐式 ODE 求解器,另一个是在图中使用 scipy 的强大最小化器(我正在考虑在需要重复求解 GP 的算法中调用 py_function 包装器内的 GPFlow 优化器优化问题)。
以下代码运行并返回我所期望的(值 1)。但我不知道它是否安全。我猜如果我还要使用x
图中其他地方的值,它会给出不确定的行为。
from scipy.optimize import minimize
import tensorflow as tf
def build_func(session, y_and_grad, var):
pl = tf.placeholder(tf.float32)
assign = tf.assign(var, pl)
y, grad = y_and_grad
def func(x):
"""
This could be something that uses ops and vars in the same graph,
but requires iterative access to session calls.
E.g. using scipy.minimize with tensorflow to compute the gradients.
"""
x = x.numpy()
def fun_and_jac(x):
session.run(assign,{pl:x})
y, jac = session.run(y_and_grad)
return y, jac
res = minimize(fun_and_jac, x0=x, jac=True)
return res.x
return func
with tf.Session(graph=tf.Graph()) as session:
x = tf.Variable([0.], dtype=tf.float32)
session.run(x.initializer)
y = (x - 1.)**2
grad = tf.gradients(y, x)[0]
func = build_func(session, [y, grad], x)
opt_x = tf.py_function(func, [x], [tf.float32])
print(session.run(opt_x))