Weight Decay

When you see the term weight decay, the literal reading ("the weight parameters decay") makes it easy to assume it is equivalent to L2 regularization, since L2 regularization also shrinks the weights. In reality, though, the two differ both in concept and in where they are implemented.

L2 regularization is typically added as an extra term to the loss function: it participates in the loss computation and influences the gradients through backpropagation. Weight decay, by contrast, shrinks the weights directly in the optimizer's parameter-update step; it is an operation internal to the optimizer.

L2 regularization (a penalty term in the loss):

$$\tilde{L}(w) = L(w) + \frac{\lambda}{2}\lVert w \rVert^2$$

weight decay (a shrinkage applied in the update step):

$$w_{t+1} = (1 - \lambda)\,w_t - \eta\,\nabla L(w_t)$$
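In PyTorch terms the contrast looks roughly like this. A minimal sketch; the toy model, data, and the lambda_l2 value are illustrative assumptions:

import torch

model = torch.nn.Linear(4, 1)          # toy model (assumption)
x, y = torch.randn(8, 4), torch.randn(8, 1)
lambda_l2 = 1e-4                       # regularization strength (assumption)

# L2 regularization: the penalty is part of the loss itself
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = ((model(x) - y) ** 2).mean()
# the 1/2 factor makes the penalty's gradient exactly lambda_l2 * w
loss = loss + 0.5 * lambda_l2 * sum((p ** 2).sum() for p in model.parameters())
loss.backward()
optimizer.step()

# weight decay: the loss stays untouched; the optimizer shrinks the weights
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=lambda_l2)

For plain SGD the two produce the same update; the sections below show how weight decay is wired into each optimizer, and the AdamW section shows where the equivalence breaks down.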

Both L2 regularization and weight decay shrink the weights, which lowers model complexity and reduces overfitting. Suppose a model already fits most normal samples well, but the data also contains a few outliers far from the overall distribution. To minimize the loss, the model is forced to fit those outliers: they produce large errors, hence large gradients, which in turn make the parameter updates large.

With weight decay, the magnitude of the parameter updates is damped: an update that would have been +8 might become only +3. The model therefore does not make drastic parameter adjustments for a handful of outliers, which means it fits less noise and overfits less.

Weight decay happens inside the optimizer, while L2 regularization is added to the loss function. So how do common optimizers actually incorporate the weight decay term?

1. Weight Decay in SGD with Momentum

Let's look at the optimizers implemented in PyTorch to see exactly how weight decay enters the computation. First, the source code of SGD with Momentum:

# Imports required by this and the following excerpts from PyTorch's source
import math
from typing import List, Optional

import torch
from torch import Tensor

def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool):

    for i, param in enumerate(params):

        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)

In the code, d_p is the parameter's gradient. Before computing the momentum, SGD with Momentum applies the weight decay operation to the gradient, i.e. d_p = d_p + param * weight_decay, and only then accumulates the result into the momentum buffer.
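As a quick sanity check, here is a minimal sketch (toy values, momentum disabled) confirming that torch.optim.SGD with weight_decay matches the manual update above:

import torch

param = torch.tensor([1.0, -2.0], requires_grad=True)
opt = torch.optim.SGD([param], lr=0.1, weight_decay=0.01)

loss = (param ** 2).sum()
loss.backward()                    # param.grad == 2 * param

# manual update: w <- w - lr * (grad + weight_decay * w)
expected = param.detach() - 0.1 * (param.grad + 0.01 * param.detach())
opt.step()
print(torch.allclose(param.detach(), expected))  # True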

2. Weight Decay in AdaGrad

def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            *,
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):

    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
        else:
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)

AdaGrad does the same as SGD with Momentum: before accumulating the squared gradients, it applies weight decay to the gradient, grad = grad.add(param, alpha=weight_decay), i.e. grad = grad + param * weight_decay. (The _make_sparse call in the sparse branch is a private PyTorch helper.)
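Written out, the dense-branch update with weight decay folded in is (with λ = weight_decay and η = lr):

$$
\begin{aligned}
g_t &= \nabla L(w_t) + \lambda w_t \\
G_t &= G_{t-1} + g_t^2 \\
\eta_t &= \frac{\eta}{1 + (t - 1)\cdot \text{lr\_decay}} \\
w_{t+1} &= w_t - \eta_t\,\frac{g_t}{\sqrt{G_t} + \epsilon}
\end{aligned}
$$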

3. Weight Decay in RMSProp

def rmsprop(params: List[Tensor],
            grads: List[Tensor],
            square_avgs: List[Tensor],
            grad_avgs: List[Tensor],
            momentum_buffer_list: List[Tensor],
            *,
            lr: float,
            alpha: float,
            eps: float,
            weight_decay: float,
            momentum: float,
            centered: bool):

    for i, param in enumerate(params):
        grad = grads[i]
        square_avg = square_avgs[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

        if centered:
            grad_avg = grad_avgs[i]
            grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(eps)
        else:
            avg = square_avg.sqrt().add_(eps)

        if momentum > 0:
            buf = momentum_buffer_list[i]
            buf.mul_(momentum).addcdiv_(grad, avg)
            param.add_(buf, alpha=-lr)
        else:
            param.addcdiv_(grad, avg, value=-lr)

RMSProp is the same: before computing the moving average of the squared gradients, it first applies weight decay to the gradient, grad = grad.add(param, alpha=weight_decay).
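In equation form, the non-centered, no-momentum branch with weight decay folded in reads (the centered variant additionally subtracts the square of the gradient's moving average before the square root):

$$
\begin{aligned}
g_t &= \nabla L(w_t) + \lambda w_t \\
s_t &= \alpha\, s_{t-1} + (1 - \alpha)\, g_t^2 \\
w_{t+1} &= w_t - \eta\,\frac{g_t}{\sqrt{s_t} + \epsilon}
\end{aligned}
$$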

4. Weight Decay in Adam

def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)

Adam does the same as all of the optimizers above: it first applies weight decay to the gradient, grad = grad.add(param, alpha=weight_decay).
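Spelled out, with bias corrections matching the code above:

$$
\begin{aligned}
g_t &= \nabla L(w_t) + \lambda w_t \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
w_{t+1} &= w_t - \eta\,\frac{m_t/(1 - \beta_1^t)}{\sqrt{v_t/(1 - \beta_2^t)} + \epsilon}
\end{aligned}
$$

Note that because the decay term λw_t is folded into g_t, it also enters the second-moment estimate v_t and gets rescaled by the adaptive denominator. This coupling is exactly what AdamW removes, as the next section shows.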

5. Weight Decay in AdamW

def adamw(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          max_exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          *,
          amsgrad: bool,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)

AdamW is a little different. If you look closely, the AdamW and Adam implementations are almost identical; only the weight-decay part of the code differs:

# AdamW
# param = param * (1 - lr * weight_decay)
param.mul_(1 - lr * weight_decay)

# Adam
if weight_decay != 0:
    grad = grad.add(param, alpha=weight_decay)
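In update form, with $\hat m_t = m_t/(1-\beta_1^t)$ and $\hat v_t = v_t/(1-\beta_2^t)$:

$$\text{Adam:}\quad g_t = \nabla L(w_t) + \lambda w_t, \qquad w_{t+1} = w_t - \eta\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}$$

$$\text{AdamW:}\quad g_t = \nabla L(w_t), \qquad w_{t+1} = (1 - \eta\lambda)\,w_t - \eta\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}$$

In Adam the decay term passes through the moment estimates and is rescaled by the adaptive denominator; in AdamW it acts on the weights at the full rate ηλ, independent of the gradient statistics.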

Adam folds weight decay into the gradient right at the start, whereas AdamW applies the decay directly to the parameters and only then performs the remaining computation. Judging from the code, AdamW and Adam are equivalent when weight_decay = 0; only with a nonzero weight_decay do the two compute different updates. To verify this, I ran a small experiment: changing AdamW's weight_decay from its default of 1e-2 to 0 produces results that are exactly identical to Adam's.

[Figure: training output of AdamW with weight_decay = 0 (top) vs. Adam (bottom); the results are identical.]
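A minimal sketch of that comparison (the toy data, model, and hyperparameters are assumptions, not the original experiment):

import torch

torch.manual_seed(0)
x = torch.randn(16, 4)
y = torch.randn(16, 1)

def run(opt_cls, **kwargs):
    torch.manual_seed(0)                       # identical init for both runs
    model = torch.nn.Linear(4, 1)
    opt = opt_cls(model.parameters(), lr=1e-3, **kwargs)
    for _ in range(100):
        opt.zero_grad()
        loss = ((model(x) - y) ** 2).mean()
        loss.backward()
        opt.step()
    return model.weight.detach().clone()

w_adamw = run(torch.optim.AdamW, weight_decay=0.0)
w_adam = run(torch.optim.Adam)                 # Adam defaults to weight_decay=0
print(torch.allclose(w_adamw, w_adam))         # True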

Some references note that AdamW implements weight decay correctly by decoupling it from the adaptive gradient scaling, and regard it as an improvement over Adam (Loshchilov & Hutter, "Decoupled Weight Decay Regularization", 2019).
