我试图区分一个函数,该函数在给定偏移平均值的情况下近似包含在 2 个限制(截断的高斯)内的高斯分数。
jnp.grad
不允许我区分添加布尔过滤器(注释行),所以我不得不即兴使用 sigmoid。
但是,现在当截断边界很高时梯度总是 nan 我不明白为什么。
在下面的示例中,我正在计算具有 0 均值和 std=1 的高斯梯度,然后我用x
.
如果我减小边界,则函数的行为符合预期。但这不是解决方案。当边界很高时,始终belows
变为 1。但如果是这种情况并且x
对下面没有影响,那么它对梯度的贡献应该是0而不是nan。但如果我返回belows[0][0]
而不是返回jnp.mean(filt, axis=0)
,我仍然得到nan
。
有任何想法吗?提前致谢(github上也有一个问题)
import os
from tqdm import tqdm
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' # Use 8 CPU devices
import numpy as np
from jax.config import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
from jax import vmap
from functools import reduce
def sigmoid(x, scale=100):
return 1 / (1 + jnp.exp(-x*scale))
def above_lower(x, l, scale=100):
return sigmoid(x - l, scale)
def below_upper(x, u, scale=100):
return 1 - sigmoid(x - u, scale)
def combine_soft_filters(a):
return jnp.prod(jnp.stack(a), axis=0)
def fraction_not_truncated(mu, v, limits, stdnorm_samples):
L = jnp.linalg.cholesky(v)
y = vmap(lambda x: jnp.dot(L, x))(stdnorm_samples) + mu
# filt = reduce(jnp.logical_and, [(y[..., i] > l) & (y[..., i] < u) for i, (l, u) in enumerate(limits)])
aboves = [above_lower(y[..., i], l) for i, (l, u) in enumerate(limits)]
belows = [below_upper(y[..., i], u) for i, (l, u) in enumerate(limits)]
filt = combine_soft_filters(aboves+belows)
return jnp.mean(filt, axis=0)
limits = np.array([
[0.,1000],
])
stdnorm_samples = np.random.multivariate_normal([0], np.eye(1), size=1000)
def func(x):
return fraction_not_truncated(jnp.zeros(1)+x, jnp.eye(1), limits, stdnorm_samples)
_x = np.linspace(-2, 2, 500)
gradfunc = jax.grad(func)
vals = [func(x) for x in tqdm(_x)]
grads = [gradfunc(x) for x in tqdm(_x)]
print(vals)
print(grads)
import matplotlib.pyplot as plt
plt.plot(_x, np.asarray(vals))
plt.ylabel('f(x)')
plt.twinx()
plt.plot(_x, np.asarray(grads), c='r')
plt.ylabel("f(x)'")
plt.title('Fraction not truncated')
plt.axhline(0, color='k', alpha=0.2)
plt.xlabel('shift')
plt.tight_layout()
plt.show()
[DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64)]
[DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64)]