0

我已经尝试根据该算法编写优化器: 算法 文章,我从中获取了算法(第 5 页):https ://arxiv.org/pdf/2106.02720.pdf

但我的解决方案不是优化任何东西。当我试图修复它时,我发现 d_p 仅由零组成,但我不明白为什么(

import torch
from torch.optim import optimizer
class BAMSGD(optimizer.Optimizer):
  def __init__(self, params, lr=float(1e-2), weight_decay=0, gamma=0.9):
    default = dict(lr=lr, weight_decay=weight_decay, gamma=gamma)
    super(BAMSGD, self).__init__(params, default)

  def __setstate__(self, state):
    super(BAMSGD, self).__setstate__(state)

  def step(self, closure=None):

    loss = None
    if closure is not None:
      loss=closure()

    for group in self.param_groups:
      weight_decay = group["weight_decay"]
      learning_rate = group["lr"]
      gamma = group["gamma"]
      for p in range(len(group["params"])):
        t = p
        p = group["params"][p]
        param_state = self.state[p]
        if "wt_ag" not in param_state:
          param_state["wt_ag"] = copy.deepcopy(p.data)
        if "wt" not in param_state:
          param_state["wt"] = copy.deepcopy(p.data)
        beta_t = 1 + t/6
        gamma_t = gamma * (t+1)
        b = 200 # at the moment it's just a random number
        wt_ag = param_state["wt_ag"]
        wt = param_state["wt"]
        if p.grad is None:
            continue
        d_p = p.grad.data
        d_p.add_(1e-2, p.data)
        '''if weight_decay != 0:
            d_p.add_(weight_decay, p.data)'''
        # buffer.mul_(1-pow(beta_t, -1)) #
        # buffer2.mul_(pow(beta_t, -1)) #
        # buffer.add_(buffer2) #
        wt.add(-(gamma * d_p)) ##
        wt_md = pow(beta_t, -1) * wt + (1-pow(beta_t, -1)) * wt_ag
        ww = wt - gamma * d_p
        w_t1 = min(1, b/(wt.norm())) * wt
        w_t1ag = pow(beta_t, -1) * w_t1 + (1 - pow(beta_t, -1)) * wt_ag
        # self.__setstate__(w_t1ag, w_t1)
        p.data.add_(-group['lr'], w_t1)
    return loss

帮助我,我是优化新手,我可能在代码中犯了很多错误(

4

0 回答 0