1

我们正在尝试实现一个分段函数,基本上是大约 100 个具有不同系数的多项式,具体取决于 x 的值。

这将在 TensorFlow 或带有 JIT 的 jax 中实现,并针对数据数组进行优化。问题是实现这一目标的最佳方法可能是什么?

可以使用一百个 where,但这并不是最佳选择。或使用tf.switch_casewith tf.vectorize_map(或类似)。

有什么想法吗?

4

1 回答 1

0

如果我理解正确,我认为这jax.lax.switch提供了您感兴趣的功能。例如:

import jax.numpy as jnp
from jax import vmap, lax
import matplotlib.pyplot as plt

def f1(x):
  return 0.0 * x

def f2(x):
  return (x - 1.0) ** 2

def f3(x):
  return 2 * x - 3

branches = (f1, f2, f3)
bounds = jnp.array([1, 2])  # boundaries between branches

x = jnp.linspace(0, 3)
index = jnp.searchsorted(bounds, x)  # index in branches for each value in x

result = vmap(lambda i, x: lax.switch(i, branches, x))(index, x)
plt.plot(x, result)

在此处输入图像描述

于 2021-06-15T17:27:15.273 回答