2

我们目前正在将 Dask Gateway 与仅 CPU 工作人员一起使用。然而,随着深度学习被更广泛地采用,我们希望过渡到为通过 Dask Gateway 创建的集群添加 GPU 支持。

我查看了 Dask Gateway 文档,关于如何设置它以及我们需要更改 helm chart/config 的哪些部分以启用此功能的详细说明并没有太多。

我的想法是首先在 GCP 上的 GKE 集群中添加一个 GPU,然后为使用该 GPU 的 dask 工作人员使用 RAPIDS dockerfile?这就是 Dask Gateway 所需的全部设置吗?

如果有人能指出我正确的方向,将不胜感激。

4

1 回答 1

2

要在支持 GPU 计算的 Kubernetes 上运行 Dask 集群,您需要以下内容:

  • Kubernetes 节点需要 GPU 和驱动程序。这可以通过NVIDIA k8s 设备插件进行设置。
  • 调度程序和工作 pod 将需要安装了 NVIDIA 工具的 Docker 映像。正如您所建议的那样,RAPIDS 图像对此很有用。
  • pod 容器规范将需要 GPU 资源,例如resources.limits.nvidia.com/gpu: 1
  • Dask 工作人员需要使用包中的dask-cuda-worker命令dask_cuda(包含在 RAPIDS 映像中)启动。

注意: 对于 Dask Gateway,您的容器映像还需要dask-gateway安装包。我们可以将其配置为在运行时安装,但最好创建一个安装了此包的自定义映像。

因此,这里有一个最小的 Dask Gateway 配置,它将为您提供一个 GPU 集群。

# config.yaml
gateway:
  backend:
    image:
      name: rapidsai/rapidsai
      tag: cuda11.0-runtime-ubuntu18.04-py3.8  # Be sure to match your k8s CUDA version and user's Python version

    worker:
      extraContainerConfig:
        env:
          - name: EXTRA_PIP_PACKAGES
            value: "dask-gateway"
        resources:
          limits:
            nvidia.com/gpu: 1  # This could be >1, you will get one worker process in the pod per GPU

    scheduler:
      extraContainerConfig:
        env:
          - name: EXTRA_PIP_PACKAGES
            value: "dask-gateway"
        resources:
          limits:
            nvidia.com/gpu: 1  # The scheduler requires a GPU in case of accidental deserialisation

  extraConfig:
    cudaworker: |
      c.ClusterConfig.worker_cmd = "dask-cuda-worker"

我们可以通过启动 Dask 网关、创建 Dask 集群并运行一些 GPU 特定的工作来测试事情是否正常。这是一个示例,我们从每个工作人员那里获取 NVIDIA 驱动程序版本。

$ helm install dgwtest daskgateway/dask-gateway -f config.yaml
In [1]: from dask_gateway import Gateway

In [2]: gateway = Gateway("http://dask-gateway-service")

In [3]: cluster = gateway.new_cluster()

In [4]: cluster.scale(1)

In [5]: from dask.distributed import Client

In [6]: client = Client(cluster)

In [7]: def get_nvidia_driver_version():
   ...:     import pynvml
   ...:     return pynvml.nvmlSystemGetDriverVersion()
   ...: 

In [9]: client.run(get_nvidia_driver_version)
Out[9]: {'tls://10.42.0.225:44899': b'450.80.02'}
于 2020-12-09T10:05:09.807 回答