4

我正在使用 optuna 对我的自定义模型进行参数优化。

有没有办法对参数进行采样,直到当前的参数集之前没有测试过?我的意思是,如果过去有一些使用相同参数集的试验,请尝试对另一个参数进行采样。

在某些情况下,这是不可能的,例如,当存在分类分布并且n_trials大于可能的唯一采样值的数量时。

我想要的是:有一些配置参数num_attempts,以便num_attempts在 for 循环中对参数进行采样,直到有一个之前没有测试过的集合,否则 - 在最后一个采样的集合上运行试验。

为什么我需要这个:仅仅因为在相同参数上多次运行重型模型的成本太高。

我现在做什么:只做这个“for-loop”的东西,但它很乱。

如果有另一种聪明的方式来做到这一点 - 将非常感谢信息。

谢谢!

4

2 回答 2

9

据我所知,目前没有直接的方法可以处理您的案件。作为一种解决方法,您可以检查参数重复并跳过评估,如下所示:

import optuna

def objective(trial: optuna.Trial):
    # Sample parameters.
    x = trial.suggest_int('x', 0, 10)
    y = trial.suggest_categorical('y', [-10, -5, 0, 5, 10])

    # Check duplication and skip if it's detected.
    for t in trial.study.trials:
        if t.state != optuna.structs.TrialState.COMPLETE:
            continue

        if t.params == trial.params:
            return t.value  # Return the previous value without re-evaluating it.

            # # Note that if duplicate parameter sets are suggested too frequently,
            # # you can use the pruning mechanism of Optuna to mitigate the problem.
            # # By raising `TrialPruned` instead of just returning the previous value,
            # # the sampler is more likely to avoid sampling the parameters in the succeeding trials.
            #
            # raise optuna.structs.TrialPruned('Duplicate parameter set')

    # Evaluate parameters.
    return x + y

# Start study.
study = optuna.create_study()

unique_trials = 20
while unique_trials > len(set(str(t.params) for t in study.trials)):
    study.optimize(objective, n_trials=1)
于 2019-11-13T09:22:27.557 回答
1

要第二个@sile 的代码注释,您可以编写一个修剪器,例如:

class RepeatPruner(BasePruner):
    def prune(self, study, trial):
        # type: (Study, FrozenTrial) -> bool

        trials = study.get_trials(deepcopy=False)
        completed_trials = [t.params for t in trials if t.state == TrialState.COMPLETE]
        n_trials = len(completed_trials)

        if n_trials == 0:
            return False

        if trial.params in completed_trials:
            return True

        return False

然后将修剪器称为:

study = optuna.create_study(study_name=study_name, storage=storage, load_if_exists=True, pruner=RepeatPruner())
于 2020-04-22T00:28:53.377 回答