我们正在尝试实现一个分段函数,基本上是大约 100 个具有不同系数的多项式,具体取决于 x 的值。
这将在 TensorFlow 或带有 JIT 的 jax 中实现,并针对数据数组进行优化。问题是实现这一目标的最佳方法可能是什么?
可以使用一百个 where,但这并不是最佳选择。或使用tf.switch_case
with tf.vectorize_map
(或类似)。
有什么想法吗?
我们正在尝试实现一个分段函数,基本上是大约 100 个具有不同系数的多项式,具体取决于 x 的值。
这将在 TensorFlow 或带有 JIT 的 jax 中实现,并针对数据数组进行优化。问题是实现这一目标的最佳方法可能是什么?
可以使用一百个 where,但这并不是最佳选择。或使用tf.switch_case
with tf.vectorize_map
(或类似)。
有什么想法吗?
如果我理解正确,我认为这jax.lax.switch
提供了您感兴趣的功能。例如:
import jax.numpy as jnp
from jax import vmap, lax
import matplotlib.pyplot as plt
def f1(x):
return 0.0 * x
def f2(x):
return (x - 1.0) ** 2
def f3(x):
return 2 * x - 3
branches = (f1, f2, f3)
bounds = jnp.array([1, 2]) # boundaries between branches
x = jnp.linspace(0, 3)
index = jnp.searchsorted(bounds, x) # index in branches for each value in x
result = vmap(lambda i, x: lax.switch(i, branches, x))(index, x)
plt.plot(x, result)