2

我正在尝试使用 optuna 搜索超参数空间。

在一个特定的场景中,我在一台带有几个 GPU 的机器上训练一个模型。模型和批量大小允许我每 1 个 GPU 运行 1 次训练。因此,理想情况下,我想让 optuna 将所有试验分布在可用的 GPU 上,以便每个 GPU 上始终运行 1 个试验。

在它说的文档中,我应该在一个单独的终端中为每个 GPU 启动一个进程,例如:

CUDA_VISIBLE_DEVICES=0 optuna study optimize foo.py objective --study foo --storage sqlite:///example.db

我想避免这种情况,因为在那之后整个超参数搜索会继续进行多轮。我不想总是手动启动每个 GPU 的进程,检查所有进程何时完成,然后开始下一轮。

我看到study.optimize有一个n_jobs说法。乍一看,这似乎是完美的。 例如我可以这样做:

import optuna

def objective(trial):
    # the actual model would be trained here
    # the trainer here would need to know which GPU
    # it should be using
    best_val_loss = trainer(**trial.params)
    return best_val_loss

study = optuna.create_study()
study.optimize(objective, n_trials=100, n_jobs=8)

这会启动多个线程,每个线程都开始训练。但是,内部的培训师objective不知何故需要知道它应该使用哪个 GPU。有什么诀窍可以做到这一点吗?

4

2 回答 2

6

经过几次精神崩溃后,我发现我可以使用multiprocessing.Queue. 要将其纳入目标函数,我需要将其定义为 lambda 函数或类(我猜部分也可以)。例如

from contextlib import contextmanager
import multiprocessing
N_GPUS = 2

class GpuQueue:

    def __init__(self):
        self.queue = multiprocessing.Manager().Queue()
        all_idxs = list(range(N_GPUS)) if N_GPUS > 0 else [None]
        for idx in all_idxs:
            self.queue.put(idx)

    @contextmanager
    def one_gpu_per_process(self):
        current_idx = self.queue.get()
        yield current_idx
        self.queue.put(current_idx)


class Objective:

    def __init__(self, gpu_queue: GpuQueue):
        self.gpu_queue = gpu_queue

    def __call__(self, trial: Trial):
        with self.gpu_queue.one_gpu_per_process() as gpu_i:
            best_val_loss = trainer(**trial.params, gpu=gpu_i)
            return best_val_loss

if __name__ == '__main__':
    study = optuna.create_study()
    study.optimize(Objective(GpuQueue()), n_trials=100, n_jobs=8)
于 2020-05-14T14:01:37.583 回答
0

如果您想要将参数传递给多个作业使用的目标函数的文档化解决方案,那么 Optuna文档提供了两种解决方案:

  • 可调用的类(它可以与多处理结合),
  • lambda 函数包装器(注意:更简单,但不适用于多处理)。

如果您准备采取一些捷径,那么您可以通过将全局值(常量,例如使用的 GPU 数量)直接(通过 python 环境)传递给__call__()方法(而不是作为 的参数__init__())来跳过一些样板。

可调用类解决方案经过测试可optuna==2.0.0与两个多处理后端(loky/multiprocessing)和远程数据库后端(mariadb/postgresql)一起工作(in)。

于 2020-08-29T11:53:01.773 回答