我最近开始尝试有趣的 python 库Jax,其中包含增强的 Numpy 和自动微分器。我想尝试创建的是一个粗略的“可微渲染器”,通过在 python 中编写着色器和损失函数,然后使用 Jax 的 AD 来查找渐变。然后我们应该能够通过在这个损失梯度上运行梯度下降来逆向渲染图像。我用简单的着色器让它工作得很好,但是当我使用布尔表达式时我遇到了问题。这是我的着色器的代码,它生成一个棋盘图案:
import jax.numpy as np
class CheckerShader:
def __init__(self, scale: float, color1: np.ndarray, color2: np.ndarray):
self.color1 = None
self.color2 = None
self.scale = None
self.scale_min = 0
self.scale_max = 20
self.color1 = color1
self.color2 = color2
self.scale = scale * 20
def checker(self, x: float, y: float) -> float:
xi = np.abs(np.floor(x))
yi = np.abs(np.floor(y))
first_col = np.mod(xi, 2) == np.mod(yi, 2)
return first_col
def shade(self, x: float, y: float):
x = x * self.scale
y = y * self.scale
first_col = self.checker(x, y)
if first_col:
return self.color1
else:
return self.color2
这是我的渲染函数,这是 JIT 失败的第一个地方:
import jax.numpy as np
import numpy as onp
import jax
def render(scale, c1, c2):
img = onp.zeros((WIDTH, HEIGHT, CHANNELS))
sh = CheckerShader(scale, c1, c2)
jit_func = jax.jit(sh.shade)
for y in range(HEIGHT):
for x in range(WIDTH):
val = jit_func(x / WIDTH, y / HEIGHT)
img[y, x, :] = val
return img
我收到的错误信息是:
TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
我猜这是因为您不能在具有布尔值的函数上运行 JIT,其值取决于编译时未确定的内容。但是我怎样才能重写它以使用 JIT 呢?如果没有 JIT,它会非常缓慢。
我的另一个问题是,我能做些什么来加速 Jax 的 Numpy?使用普通 Numpy 渲染我的图像(100x100 像素)需要几毫秒,但使用 Jax 的 Numpy 需要几秒钟!感谢:D