0

我正在使用 Jax 做一些机器学习工作。Jax 使用 XLA 进行一些即时编译以加速,但编译本身在 CPU 上太慢了。我的情况是 CPU 只会使用一个内核来进行编译,这根本没有效率。

我找到了一些答案,如果我可以使用 GPU 进行编译,它会非常快。谁能告诉我如何使用 GPU 来完成编译部分?由于我没有对编译进行任何配置。谢谢!

问题的一些补充:我正在使用 Jax 计算 grad 和 hessian,这会使编译非常慢。代码如下:

    ## get results from model ##
    def get_model_value(images):
        return jnp.sum(model(images))

    def get_model_grad(images):
        images = jnp.expand_dims(images, axis=0)
        image_grad = jacfwd(get_model_value)(images)
        return image_grad
    
    def get_model_hessian(images):
        images = jnp.expand_dims(images, axis=0)
        image_hess = jacfwd(jacrev(get_model_value))(images)
        return image_hess
  
    # get value
    model_value = model(dis_img)
    FR_value = jnp.expand_dims(FR_value, axis=1)
    value_loss = crit_mse(model_value, FR_value)
    
    # get grad
    vmap_model_grad = jax.vmap(get_model_grad)
    model_grad = vmap_model_grad(dis_img)
    
    # get hessian
    vmap_model_hessian = vmap(get_model_hessian)
    model_hessian = vmap_model_hessian(dis_img)
4

0 回答 0