The link you reference really only gives you access to the data while it is still sitting on the GPU. Using that data in another framework, like TensorFlow or PyTorch, is not that simple.
TL;DR: Unless you have a library that is explicitly set up to work with the RAPIDS accelerator, you probably want to run your ETL with RAPIDS, save the result, and then launch a new job to train your models using that data.
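A minimal sketch of that save-then-train handoff, in case it helps. The function names `run_etl` and `list_completed_parquet` are made up for illustration, and `dropna()` is only a placeholder for your real ETL. It also assumes the output is written with Spark's default committer, which normally drops a `_SUCCESS` marker file next to the `part-*` files when the job finishes; if your setup disables that marker you would need a different completeness check.

```python
from pathlib import Path


def run_etl(spark, input_path, output_dir):
    """Job 1: run inside Spark (with the RAPIDS accelerator enabled).

    `spark` is an existing SparkSession. The dropna() call is only a
    stand-in for whatever ETL you actually need.
    """
    df = spark.read.parquet(input_path)
    cleaned = df.dropna()
    cleaned.write.mode("overwrite").parquet(output_dir)


def list_completed_parquet(output_dir):
    """Job 2: run in the separate training process.

    Returns the Parquet part files written by job 1, sorted by name.
    Raises FileNotFoundError when the _SUCCESS marker is missing, which
    usually means the ETL job has not finished or has failed.
    """
    out = Path(output_dir)
    if not (out / "_SUCCESS").is_file():
        raise FileNotFoundError(f"no _SUCCESS marker in {out}")
    return sorted(p for p in out.glob("part-*.parquet") if p.is_file())
```

The training job can then load the returned files with whatever Parquet reader its framework supports. The point of the split is that the two jobs never need to share a process or a GPU at the same time.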
To hand the data over directly instead, there are still a number of issues that you would need to solve. We have worked on these in the case of XGBoost, but we have not tried to tackle them for TensorFlow or PyTorch yet.
The big issues are:
- Getting the data to the correct process. Even if the data is on the GPU, for security reasons it is tied to a given user process. PyTorch and TensorFlow generally run as Python processes and not in the same JVM that Spark is running in. This means that the data has to be sent to the other process. There are several ways to do this, but it is non-trivial to do it as a zero-copy operation.
- The format of the data is not what TensorFlow or PyTorch want. The data for RAPIDS is in an Arrow-compatible format. TensorFlow and PyTorch have APIs for importing data in standard formats from the CPU, but it might take a bit of work to get the data into a format that the frameworks want and to find an API that lets you pull it in directly from the GPU.
- Sharing GPU resources. Spark only recently added support for scheduling GPUs. Prior to that, people would just launch a single Spark task per executor and a single Python process, so that the Python process would own the entire GPU when doing training or inference. With the RAPIDS accelerator the GPU is not free any more, and you need a way to share the resources. RMM provides some of this if both libraries are updated to use it and they are in the same process, but PyTorch and TensorFlow typically run in separate Python processes, so figuring out how to share the GPU is hard.
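If you go with two separate jobs, the GPU sharing problem mostly goes away, because each job can be given its own GPU scheduling settings. Here is a rough sketch of what that might look like. The helper `spark_submit_args` and the two dictionaries are hypothetical, and the values are only illustrative, so check them against your Spark and RAPIDS versions. The configuration keys themselves are the standard Spark 3 GPU scheduling settings plus the RAPIDS plugin class.

```python
def spark_submit_args(conf):
    """Turn a dict of Spark settings into spark-submit --conf arguments."""
    args = []
    for key, value in conf.items():
        args += ["--conf", f"{key}={value}"]
    return args


# ETL job: RAPIDS plugin enabled, several Spark tasks share one GPU.
# 0.25 is an illustrative value meaning four tasks per GPU.
ETL_CONF = {
    "spark.plugins": "com.nvidia.spark.SQLPlugin",
    "spark.executor.resource.gpu.amount": "1",
    "spark.task.resource.gpu.amount": "0.25",
}

# Training job: no RAPIDS plugin, one task per GPU, so the Python
# process doing the training owns the whole device.
TRAIN_CONF = {
    "spark.executor.resource.gpu.amount": "1",
    "spark.task.resource.gpu.amount": "1",
}
```

For example, `spark_submit_args(TRAIN_CONF)` gives the `--conf` arguments to append to the `spark-submit` command line for the training job.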