1

我最近开始尝试有趣的 python 库Jax,其中包含增强的 Numpy 和自动微分器。我想尝试创建的是一个粗略的“可微渲染器”,通过在 python 中编写着色器和损失函数,然后使用 Jax 的 AD 来查找渐变。然后我们应该能够通过在这个损失梯度上运行梯度下降来逆向渲染图像。我用简单的着色器让它工作得很好,但是当我使用布尔表达式时我遇到了问题。这是我的着色器的代码,它生成一个棋盘图案:

import jax.numpy as np

class CheckerShader:

    def __init__(self, scale: float, color1: np.ndarray, color2: np.ndarray):
        self.color1 = None
        self.color2 = None
        self.scale = None
        self.scale_min = 0
        self.scale_max = 20
        self.color1 = color1
        self.color2 = color2
        self.scale = scale * 20

    def checker(self, x: float, y: float) -> float:
        xi = np.abs(np.floor(x))
        yi = np.abs(np.floor(y))

        first_col = np.mod(xi, 2) == np.mod(yi, 2)
        return first_col

    def shade(self, x: float, y: float):
        x = x * self.scale
        y = y * self.scale

        first_col = self.checker(x, y)
        if first_col:
            return self.color1
        else:
            return self.color2

这是我的渲染函数,这是 JIT 失败的第一个地方:

import jax.numpy as np
import numpy as onp
import jax

def render(scale, c1, c2):
    img = onp.zeros((WIDTH, HEIGHT, CHANNELS))
    sh = CheckerShader(scale, c1, c2)
    jit_func = jax.jit(sh.shade)

    for y in range(HEIGHT):
        for x in range(WIDTH):
            val = jit_func(x / WIDTH, y / HEIGHT)
            img[y, x, :] = val

    return img

我收到的错误信息是:

TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

我猜这是因为您不能在具有布尔值的函数上运行 JIT,其值取决于编译时未确定的内容。但是我怎样才能重写它以使用 JIT 呢?如果没有 JIT,它会非常缓慢。

我的另一个问题是,我能做些什么来加速 Jax 的 Numpy?使用普通 Numpy 渲染我的图像(100x100 像素)需要几毫秒,但使用 Jax 的 Numpy 需要几秒钟!感谢:D

4

2 回答 2

3

代替

if first_col:
    return self.color1
else:
    return self.color2

return np.where(first_col, self.color1, self.color2)
于 2020-03-09T12:30:53.957 回答
2

但是我怎样才能重写它以使用 JIT 呢?

Ivo 在这里有一个很好的答案 - 只需使用np.where.

我的另一个问题是,我能做些什么来加速 Jax 的 Numpy?

速度慢的原因可能有三个。

首先是JITing的性质。第一次运行代码会很慢,但是如果多次运行相同的代码,速度应该会提高。如果可能的话,我也会尝试 JIT 整个渲染功能,如果您打算多次运行它。

第二个原因是 numpy 和 jax.numpy 之间的切换会很慢。你写

img = onp.zeros((WIDTH, HEIGHT, CHANNELS))

但如果你写它会快得多

img = np.zeros((WIDTH, HEIGHT, CHANNELS))

第三是您正在循环宽度和高度,而不是使用矢量化操作。我不明白为什么你不能以完全矢量化的形式做到这一点。

于 2020-05-31T18:51:42.273 回答