我需要能够在调试输出中清楚地区分张量。让我用一个示例问题来说明这一点:
import tensorflow as tf
def loss(x):
return x**2
x = tf.Variable(5,dtype=float)
y = tf.Variable(x,dtype=float)
print("x:", x)
print("y:", y)
with tf.GradientTape() as tape1:
z1 = loss(x)
grad_z1 = tape1.gradient(z1, [x])
with tf.GradientTape() as tape2:
z2 = loss(y)
grad_z2 = tape2.gradient(z2, [x]) # Variable should be y here!
print("grad_z1:", grad_z1)
print("grad_z2:", grad_z2)
输出是:
x: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
y: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
grad_z1: [<tf.Tensor: id=25, shape=(), dtype=float32, numpy=10.0>]
grad_z2: [None]
在这里,我试图获得一个简单损失函数相对于某个输入变量的梯度。在“z1”的情况下,该示例运行良好,因为存在从x
到的图形连接z1
。x
但是它在 z2 的情况下会中断,因为从到没有图形连接z2
。y
通过从 的值初始化一个新变量 ,此连接“意外”中断x
。在这个例子中问题很明显,但在我更复杂的实际代码中,意外替换像这样的变量更容易,从而破坏计算。然后我不得不四处寻找,试图找出我在哪里犯了这样的错误。
如果我可以检查张量并找出它们在哪里变成不同的对象,这个过程会容易得多。例如,是否有某种独特的 ID 属性或我可以检查的东西?在上面的示例中,我无法分辨,x
实际上y
是与打印输出完全不同的变量。它们看起来相同,但当然不是。
所以我需要其他可以打印的东西来帮助追踪x
意外被交换为y
. 有没有这样的财产?肯定有,但我找不到。也许我可以打印对象的内存地址或其他东西,但我也不确定如何在 Python 中做到这一点。