我已经尝试根据该算法编写优化器: 算法 文章,我从中获取了算法(第 5 页):https ://arxiv.org/pdf/2106.02720.pdf
但我的解决方案不是优化任何东西。当我试图修复它时,我发现 d_p 仅由零组成,但我不明白为什么(
import torch
from torch.optim import optimizer
class BAMSGD(optimizer.Optimizer):
def __init__(self, params, lr=float(1e-2), weight_decay=0, gamma=0.9):
default = dict(lr=lr, weight_decay=weight_decay, gamma=gamma)
super(BAMSGD, self).__init__(params, default)
def __setstate__(self, state):
super(BAMSGD, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss=closure()
for group in self.param_groups:
weight_decay = group["weight_decay"]
learning_rate = group["lr"]
gamma = group["gamma"]
for p in range(len(group["params"])):
t = p
p = group["params"][p]
param_state = self.state[p]
if "wt_ag" not in param_state:
param_state["wt_ag"] = copy.deepcopy(p.data)
if "wt" not in param_state:
param_state["wt"] = copy.deepcopy(p.data)
beta_t = 1 + t/6
gamma_t = gamma * (t+1)
b = 200 # at the moment it's just a random number
wt_ag = param_state["wt_ag"]
wt = param_state["wt"]
if p.grad is None:
continue
d_p = p.grad.data
d_p.add_(1e-2, p.data)
'''if weight_decay != 0:
d_p.add_(weight_decay, p.data)'''
# buffer.mul_(1-pow(beta_t, -1)) #
# buffer2.mul_(pow(beta_t, -1)) #
# buffer.add_(buffer2) #
wt.add(-(gamma * d_p)) ##
wt_md = pow(beta_t, -1) * wt + (1-pow(beta_t, -1)) * wt_ag
ww = wt - gamma * d_p
w_t1 = min(1, b/(wt.norm())) * wt
w_t1ag = pow(beta_t, -1) * w_t1 + (1 - pow(beta_t, -1)) * wt_ag
# self.__setstate__(w_t1ag, w_t1)
p.data.add_(-group['lr'], w_t1)
return loss
帮助我,我是优化新手,我可能在代码中犯了很多错误(