I'm confused by the Jax documentation. Here's what I'm trying to do:

from jax import grad

def line(m, x, b):
  return m*x + b

grad(line)(1, 2, 3)



FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-48-d14b17620b30> in <module>()
----> 4 grad(line)(1,2,3)

FilteredStackTrace: TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating), but got int32. If you want to use integer-valued inputs, use vjp or set allow_int to True.

TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating), but got int32. If you want to use integer-valued inputs, use vjp or set allow_int to True.


The following example from the JAX documentation, on the other hand, works:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)


W_grad [-0.16965576 -0.8774648  -1.4901345 ]



Jax is telling you that it doesn't accept integer inputs:

TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating), but got int32. If you want to use integer-valued inputs, use vjp or set allow_int to True.

Calling grad(line)(1., 2., 3.) with floats instead fixes the problem.
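Here is a minimal sketch of the float fix. Python float literals such as 1. become floating-point values (float32 under JAX's default settings), and grad accepts those. By default grad differentiates only with respect to the first argument (m), so the result is d(m*x + b)/dm = x. The argnums parameter selects other arguments.

```python
from jax import grad

def line(m, x, b):
    return m * x + b

# Float literals (1., 2., 3.) become float32 inputs, which grad accepts.
# By default grad differentiates w.r.t. the first argument only: d/dm (m*x + b) = x.
dm = grad(line)(1., 2., 3.)
print(dm)  # 2.0

# argnums picks the arguments to differentiate; a tuple returns one gradient per argument.
dm, dx, db = grad(line, argnums=(0, 1, 2))(1., 2., 3.)
print(dm, dx, db)  # 2.0 1.0 1.0
```

This also explains why the documentation example works: W and b come from random.normal, so they are already floating-point arrays.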

To call grad(line)(1,2,3) with int32 inputs, change it to grad(line, allow_int=True)(1,2,3).
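Here is a hedged sketch of what allow_int=True actually changes. According to the jax.grad documentation, the gradient with respect to an integer input comes back with the trivial dtype float0. That is a placeholder zero, not a derivative. So the flag silences the error, but it does not compute d/dx for an integer x. The example keeps m and b as floats so that line returns a floating-point value. Depending on your JAX version, grad may also reject the int32 output that line(1, 2, 3) produces.

```python
import jax
from jax import grad

def line(m, x, b):
    return m * x + b

# m and b are floats; x is a Python int (int32 in JAX by default).
# Differentiating w.r.t. x (argnums includes 1) needs allow_int=True;
# without it grad raises the TypeError shown above.
gm, gx = grad(line, argnums=(0, 1), allow_int=True)(1., 2, 3.)

print(gm)  # 2.0, the real derivative d/dm = x
# gx is only a float0-typed placeholder zero, not d/dx = m = 1.0
print(gx.dtype == jax.dtypes.float0)  # True
```

If you need the actual derivative with respect to x, pass x as a float, as in grad(line)(1., 2., 3.).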

