在 Optuna 中调整参数时,我的可能参数空间中有一个无效的子空间。在我的特殊情况下,如果我正在调整的两个参数都接近零(< 1e-5),则它们可能会导致极长时间的试验(我想避免),即:
A > 1e-5 | A < 1e-5 | |
---|---|---|
B > 1e-5 | 好的 | 好的 |
B < 1e-5 | 好的 | 暂停 |
当 A < 1e-5 和 B < 1e-5 时,我显然能够捕捉到这种极端情况,但是我应该如何让 Optuna 知道这是一次无效的试验?我不想更改 A 和 B 的采样范围以排除 < 1e-5 的值,因为如果 A 和 B 中只有一个小于 1e-5 就可以了。
到目前为止,我有两个想法:
引发 Optuna 修剪异常
optuna.exceptions.TrialPruned
。这将在代码超时之前修剪试验,但我不确定这是否会告诉 Optuna 这是要评估的搜索空间的糟糕区域。如果它确实引导调整远离这种边缘情况,那么我认为这是最好的选择。返回一些固定的试验分数,例如 0。我知道我的试验的分数在 0 和 1 之间,因此如果达到这个无效的边缘情况,我可以返回可能的最低分数 0。但是,如果大多数试验分数是 0.5 或更大,则边缘情况的值 0 成为极端异常值。
MWE:
import optuna
class MWETimeoutTuner:
def __call__(self, trial):
# Using a limit of 0.1 rather than 1e-5 so the edge case is triggered quicker
lim = 0.1
trial_a = trial.suggest_float('a', 0.0, 1.0)
trial_b = trial.suggest_float('d', 0.0, 1.0)
trial_c = trial.suggest_float('c', 0.0, 1.0)
trial_d = trial.suggest_float('d', 0.0, 1.0)
# Without this, we end up stuck in the infinite loop in _func_that_can_timeout
# But is pruning the trial the best way to way to avoid an invalid parameter configuration?
if trial_a < lim and trial_b < lim:
raise optuna.exceptions.TrialPruned
def _func_that_can_timeout(a, b, c, d):
# This mocks the timeout situation due to an invalid parameter configuration.
if a < lim and b < lim:
print('TIMEOUT:', a, b)
while True:
pass
# The maximum possible score would be 2 (c=1, d=1, a=0, b=0)
# However, as only one of a and b can be less than 0.1, the actual maximum is 1.9.
# Either (c=1, d=1, a=0, b=0.1) or (c=1, d=1, a=0.1, b=0)
return c + d - a - b
score = _func_that_can_timeout(trial_a, trial_b, trial_c, trial_d)
return score
if __name__ == "__main__":
tuner = MWETimeoutTuner()
n_trials = 1000
direction = 'maximize'
study_uid = "MWETimeoutTest"
study = optuna.create_study(direction=direction, study_name=study_uid)
study.optimize(tuner, n_trials=n_trials)
我发现了这个相关问题,它建议根据已采样的现有值更改采样过程。在 MWE 中,这看起来像:
trial_a = trial.suggest_float('a', 0.0, 1.0)
if trial_a < lim:
trial_b = trial.suggest_float('b', lim, 1.0)
else:
trial_b = trial.suggest_float('b', 0.0, 1.0)
但是,在测试时,这会产生以下警告:
RuntimeWarning:名称为“b”的分布参数值不一致!这可能是配置错误。Optuna 允许在试用中多次调用具有相同名称的相同发行版。当参数值不一致时,optuna 仅使用第一次调用的值并忽略所有后续调用。使用这些值:{'low': 0.1, 'high': 1.0}。>
所以这似乎不是一个有效的解决方案。
在 MWE 中,提高修剪异常是有效的,并且找到了(接近)最优值。似乎在写这个问题时我自己几乎已经回答了修剪是要走的路,除非有更好的解决方案?