4

我想训练一个简单的线性模型。x 和 y 下面的这些是我的数据。

import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)

f 是计算所有数据的均方误差的函数。

def f(params, x, y):
  return np.mean(np.power((params['w'] * x + params['b'])-y , 2))

from jax import grad
df = grad(f)
params = dict()
#initialize parameters
params['w'] = 2.4
params['b'] = 10.
df(params, x, y) # I will do this in a loop (implementing gradient decent part

这给了我一个错误:

FilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray

当我清除np.power代码时。为什么?

4

1 回答 1

4

JAX 不能计算函数的梯度numpy,但它可以计算函数的梯度jax.numpy。如果您根据 重写代码jax.numpy,它应该适合您:

import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)

import jax.numpy as jnp
def f(params, x, y):
  return jnp.mean(jnp.power((params['w'] * x + params['b'])-y , 2))

from jax import grad
df = grad(f)
params = dict()

params['w'] = 2.4
params['b'] = 10.
df(params, x, y)
# {'b': DeviceArray(14.661432, dtype=float32),
#  'w': DeviceArray(7.3792152, dtype=float32)}

TracerArrayConversionError您可以在文档页面中阅读更多详细信息。

于 2021-04-27T16:40:52.013 回答