Simba: The Feasibility of Sign Gradients

Dark Origins

Note: Lion recommends $\beta_1 = 0.9$, $\beta_2 = 0.99$.
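
For reference, Lion (Chen et al.) takes the sign of an interpolation between the momentum $m_{t-1}$ and the current gradient $g_t$; the following is a sketch of its update rule as given in the reference below:

$$
\begin{aligned}
c_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \\
\theta_t &= \theta_{t-1} - \eta \big(\operatorname{sign}(c_t) + \lambda\, \theta_{t-1}\big), \\
m_t &= \beta_2 m_{t-1} + (1 - \beta_2)\, g_t,
\end{aligned}
$$

where $\eta$ is the learning rate and $\lambda$ the weight-decay coefficient; the $c_t$ and $m_{t-1}$ appearing in the tables below refer to these quantities.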

Experimental Results

| Method | HR@10 | NDCG@10 | Loss $\downarrow$ |
| --- | --- | --- | --- |
| SASRec (Lion) | 0.0690 | 0.0381 | 0.145131 |
| SASRec (Sign Gradient) | 0.0420 | 0.0213 | 0.665475 |
| SASRec (Simba) | 0.0195 | 0.0094 | 0.872968 |

| Method | HR@10 | NDCG@10 | Loss $\downarrow$ |
| --- | --- | --- | --- |
| SASRec (Simba + $m_{t-1}$) | 0.0693 | 0.0378 | 0.145657 |

| Method | HR@10 | NDCG@10 | Loss $\downarrow$ |
| --- | --- | --- | --- |
| SASRec (Simba + (1)) | 0.0296 | 0.0152 | 0.578699 |

| Method | HR@10 | NDCG@10 | Loss $\downarrow$ |
| --- | --- | --- | --- |
| SASRec (Simba + ($c_t = m_{t-1}$)) | 0.0667 | 0.0358 | 0.142691 |

Code
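
In the shorthand used here (my own notation, read off the implementation below), $W$ is `window_size`, $\tau$ is `threshold`, $\eta$ the learning rate, and $\lambda$ the weight decay. Each parameter keeps a clipped count $c_t$ of recent gradient signs, and the update switches between the accumulated sign and the instantaneous one:

$$
\begin{aligned}
u_t &= \begin{cases} \operatorname{sign}(c_{t-1}), & |c_{t-1}| > \lfloor \tau W \rfloor, \\ \operatorname{sign}(g_t), & \text{otherwise}, \end{cases} \\
\theta_t &= (1 - \eta\lambda)\, \theta_{t-1} - \eta\, u_t, \\
c_t &= \operatorname{clip}\!\big(c_{t-1} + \operatorname{sign}(g_t),\, -W,\, W\big).
\end{aligned}
$$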


```python
import torch
from torch.optim.optimizer import Optimizer

class Simba(Optimizer):

  def __init__(
        self, 
        params, 
        lr=1e-4, 
        window_size: int = 128,
        threshold: float = 0.1,
        weight_decay=0.0
  ):
    """Initialize the hyperparameters.

    Args:
      params (iterable): iterable of parameters to optimize or dicts defining
        parameter groups
      lr (float, optional): learning rate (default: 1e-4)
      window_size (int, optional): size of the sliding window over recent
        gradient signs (default: 128)
      threshold (float, optional): fraction of the window that the accumulated
        sign counts must exceed before the accumulated sign is used instead of
        the current gradient sign (default: 0.1)
      weight_decay (float, optional): weight decay coefficient (default: 0)
    """

    if not 0.0 <= lr:
      raise ValueError('Invalid learning rate: {}'.format(lr))
    defaults = dict(lr=lr, window_size=window_size, threshold=threshold, weight_decay=weight_decay)
    super().__init__(params, defaults)

  @torch.no_grad()
  def step(self, closure=None):
    """Performs a single optimization step.

    Args:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.

    Returns:
      the loss.
    """
    loss = None
    if closure is not None:
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue

        # Decoupled weight decay, applied multiplicatively to the weights
        p.mul_(1 - group['lr'] * group['weight_decay'])

        grad = p.grad
        state = self.state[p]
        # State initialization
        if len(state) == 0:
          # Running per-parameter count of recent gradient signs
          state['sign_counts'] = torch.zeros_like(p, dtype=torch.int16)

        sign_counts = state['sign_counts']
        window_size, threshold = group['window_size'], group['threshold']
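        # Convert the fractional threshold into an absolute count of agreeing signs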
        threshold = int(window_size * threshold)
        # Use the accumulated sign once a clear majority has built up in the
        # window; otherwise fall back to the sign of the current gradient
        update = torch.where(
            sign_counts.abs() > threshold,
            sign_counts.sign(),
            grad.sign()
        )

        p.add_(update, alpha=-group['lr'])

        # Accumulate the current gradient sign and clamp the counts to the window
        sign_counts.add_(grad.sign().to(torch.int16)).clamp_(-window_size, window_size)
    return loss
```
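
A minimal usage sketch, assuming a toy model and random data purely for illustration (the two-layer network, batch shapes, and MSE loss below are placeholders, not part of the original experiments):

```python
import torch
import torch.nn.functional as F

# Hypothetical toy model; any torch.nn.Module works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 1),
)
optimizer = Simba(model.parameters(), lr=1e-4, window_size=128,
                  threshold=0.1, weight_decay=0.0)

x, y = torch.randn(32, 64), torch.randn(32, 1)  # dummy batch
for _ in range(100):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```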

References

  1. Chen X., Liang C., Huang D., Real E., Wang K., Liu Y., Pham H., Dong X., Luong T., Hsieh C., Liu Y., and Le Q. V. Symbolic Discovery of Optimization Algorithms. NeurIPS, 2024.