import tensorflow as tf
import numpy as np
x = tf.Variable(2, name='x', trainable=True, dtype=tf.float32)
with tf.GradientTape() as t:
    t.watch(x)
    log_x = tf.math.log(x)
    y = tf.math.square(log_x)
opt = tf.optimizers.Adam(0.5)
# train = opt.minimize(lambda: y, var_list=[x]) # FAILS
@tf.function
def f(x):
    log_x = tf.math.log(x)
    y = tf.math.square(log_x)
    return y
yy = f(x)
train = opt.minimize(lambda: yy, var_list=[x]) # ALSO FAILS
This produces a ValueError:
No gradients provided for any variable: ['x:0'].
This looks, in part, like the example they give. I am not sure whether this is a bug in eager execution / 2.0 or something I am doing wrong.
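For reference, here is a minimal sketch (my own example, assuming TF 2.x eager execution) of the pattern that does work: minimize needs a callable that recomputes the loss on every call, so the optimizer can trace it back to x; a tensor computed once outside the callable is effectively a constant from the optimizer's point of view.

import tensorflow as tf

x = tf.Variable(2, name='x', trainable=True, dtype=tf.float32)
opt = tf.optimizers.Adam(0.5)

def loss_fn():
    # recomputed on every call, so it stays connected to the current value of x
    return tf.math.square(tf.math.log(x))

for _ in range(10):
    opt.minimize(loss_fn, var_list=[x])  # no "No gradients provided" error here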
Update:
Since there were some questions and interesting comments, I have pasted an annotated version of the solution below.
import numpy as np
import tensorflow as tf
x = tf.Variable(3, name='x', trainable=True, dtype=tf.float32)
with tf.GradientTape(persistent=True) as t:
    # log_x = tf.math.log(x)
    # y = tf.math.square(log_x)
    y = (x - 1) ** 2
opt = tf.optimizers.Adam(learning_rate=0.001)
def get_gradient_wrong(x0):
    # this does not work: the tape above recorded y before the assign,
    # so the gradient is computed from the stale recorded values of x
    x.assign(x0)
    return t.gradient(y, [x])
def get_gradient(x0):
    # this works: the loss is re-recorded on a fresh tape after the assign
    x.assign(x0)
    with tf.GradientTape(persistent=True) as t:
        y = (x - 1) ** 2
    return t.gradient(y, [x])
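# Illustration of the difference (hypothetical numbers, given the y = (x - 1) ** 2
# loss recorded above at x = 3): get_gradient_wrong(5.0) still returns [4.0], the
# gradient from the stale tape, while get_gradient(5.0) re-records the loss and
# returns [8.0] = 2 * (5 - 1).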
#### Option 1
def a(x0, tol=1e-8, max_iter=10000):
    # does not appear to work properly
    x.assign(x0)
    err = np.inf  # step error (Banach), not the actual error
    i = 0
    while err > tol:
        x0 = x.numpy()
        # IMPORTANT: WITHOUT THIS INSIDE THE LOOP THE GRADIENTS DO NOT UPDATE
        with tf.GradientTape(persistent=True) as t:
            y = (x - 1) ** 2
        gradients = t.gradient(y, [x])
        opt.apply_gradients(zip(gradients, [x]))
        err = np.abs(x.numpy() - x0)
        print(err, x.numpy(), gradients[0].numpy())
        i += 1
        if i > max_iter:
            print(f'stopping at max_iter={max_iter}')
            return x.numpy()
    print(f'stopping at err={err}<{tol}')
    return x.numpy()
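# Hypothetical sanity check for option 1: with the tape re-created inside the
# loop, a(5.0) should drive x toward the minimum of (x - 1) ** 2 at x = 1;
# with learning_rate=0.001 it may well hit max_iter before err drops below tol=1e-8.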
#### Option 2
def b(x0, tol=1e-8, max_iter=10000):
    x.assign(x0)
    # To use minimize you have to define your loss computation as a function
    def compute_loss():
        log_x = tf.math.log(x)
        y = tf.math.square(log_x)
        return y
    err = np.inf  # step error (Banach), not the actual error
    i = 0
    while err > tol:
        x0 = x.numpy()
        opt.minimize(compute_loss, var_list=[x])
        err = np.abs(x.numpy() - x0)
        print(err, x.numpy())
        i += 1
        if i > max_iter:
            print(f'stopping at max_iter={max_iter}')
            return x.numpy()
    print(f'stopping at err={err}<{tol}')
    return x.numpy()
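For completeness, a hypothetical call of option 2 (the loss in compute_loss is log(x)^2, whose minimum is at x = 1):

x_opt = b(2.0)
print(x_opt)  # should drift toward 1.0; with learning_rate=0.001 it may stop at max_iter first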