我正在使用 Jax 做一些机器学习工作。Jax 使用 XLA 进行一些即时编译以加速,但编译本身在 CPU 上太慢了。我的情况是 CPU 只会使用一个内核来进行编译,这根本没有效率。
我找到了一些答案,如果我可以使用 GPU 进行编译,它会非常快。谁能告诉我如何使用 GPU 来完成编译部分?由于我没有对编译进行任何配置。谢谢!
问题的一些补充:我正在使用 Jax 计算 grad 和 hessian,这会使编译非常慢。代码如下:
## get results from model ##
def get_model_value(images):
return jnp.sum(model(images))
def get_model_grad(images):
images = jnp.expand_dims(images, axis=0)
image_grad = jacfwd(get_model_value)(images)
return image_grad
def get_model_hessian(images):
images = jnp.expand_dims(images, axis=0)
image_hess = jacfwd(jacrev(get_model_value))(images)
return image_hess
# get value
model_value = model(dis_img)
FR_value = jnp.expand_dims(FR_value, axis=1)
value_loss = crit_mse(model_value, FR_value)
# get grad
vmap_model_grad = jax.vmap(get_model_grad)
model_grad = vmap_model_grad(dis_img)
# get hessian
vmap_model_hessian = vmap(get_model_hessian)
model_hessian = vmap_model_hessian(dis_img)