Weight Decay

Taken literally, weight decay means decaying the weight parameters, so it is tempting to treat it as equivalent to L2 regularization, since an L2 penalty also shrinks the weights. The two concepts are different, though. L2 regularization is typically added to the loss function as a penalty term: it becomes part of the computed loss and therefore participates in backpropagation. Weight decay, on the other hand, is applied directly in the parameter-update step, i.e., it is a computation performed inside the optimizer.

L2 regularization (penalty added to the loss):

    loss_total = loss + (lambda / 2) * ||w||^2

Weight decay (applied inside the optimizer's update step, here in the coupled form that PyTorch's SGD/Adam use):

    w = w - lr * (grad + lambda * w)
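
To make the distinction concrete, here is a minimal sketch (toy tensors, arbitrarily chosen values) comparing the two approaches under plain SGD. In this particular case the two are mathematically equivalent, which is exactly why they are so often conflated:

import torch

torch.manual_seed(0)
w1 = torch.randn(3, requires_grad=True)
w2 = w1.detach().clone().requires_grad_(True)
x = torch.randn(3)
lr, lam = 0.1, 0.01

# 1) L2 regularization: the penalty lives in the loss and goes through backprop
loss = (w1 * x).sum() + 0.5 * lam * (w1 ** 2).sum()
loss.backward()
with torch.no_grad():
    w1 -= lr * w1.grad

# 2) weight decay: the penalty is applied inside the optimizer instead
opt = torch.optim.SGD([w2], lr=lr, weight_decay=lam)
(w2 * x).sum().backward()
opt.step()

print(torch.allclose(w1, w2))  # True -- identical updates for vanilla SGD

For adaptive optimizers (Adam in particular) this equivalence breaks down, as the rest of this post shows.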

Whether it is L2 regularization or weight decay, the goal of both is to reduce overfitting. How can we understand, intuitively, that decaying the weights reduces overfitting?

For example: suppose our model already fits most of the samples, but a few outlier points remain that lie far from the normal samples. The model is not that clever; it simply fits all samples according to the minimum-loss objective we gave it. It then finds that those distant outliers keep the loss high, so it is forced to fit these anomalous samples too. The gradients contributed by the outliers become large, and the parameters get adjusted by correspondingly large amounts.

If we now add weight decay, then when an outlier is encountered and a parameter would otherwise be adjusted by +8, we decay part of the update, say down to +3. In effect, the model reduces overfitting by decaying its weights.

Now, back to the main topic.

Weight decay is computed inside the optimizer, whereas L2 regularization is added to the loss function. So how do common optimizers apply the weight decay term?

1. Weight Decay in SGD with Momentum

Let's use the optimizers implemented in PyTorch to see exactly how weight decay participates in the computation. First, the source code of SGD with momentum:

# Excerpt from PyTorch's functional optimizer implementations
# (torch/optim/_functional.py); imports added so the excerpts in this
# post are self-contained (math is used by the adam/adamw excerpts below)
import math
import torch
from torch import Tensor
from typing import List, Optional

def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool):

    for i, param in enumerate(params):

        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)

In the code, d_p is the gradient of the parameter. Before accumulating momentum, SGD with momentum adds the weight decay term to the gradient, i.e. d_p = d_p + param * weight_decay, and only then performs the momentum accumulation.
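
As a sanity check, here is a small sketch (hypothetical toy values) that replays the first optimizer step by hand, following the source above, and compares it with torch.optim.SGD:

import torch

torch.manual_seed(0)
param = torch.randn(4, requires_grad=True)
lr, momentum, wd = 0.1, 0.9, 0.01

expected = param.detach().clone()
opt = torch.optim.SGD([param], lr=lr, momentum=momentum, weight_decay=wd)

(param ** 2).sum().backward()
grad = param.grad.detach().clone()
opt.step()

d_p = grad + wd * expected  # d_p = d_p + param * weight_decay
buf = d_p.clone()           # first step: momentum buffer is seeded with d_p
expected -= lr * buf        # param = param - lr * d_p

print(torch.allclose(param, expected))  # True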

2. Weight Decay in AdaGrad

def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            *,
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):

    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            # _make_sparse is a private helper defined alongside these
            # functions in PyTorch's optimizer source
            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
        else:
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)

Like SGD with momentum, AdaGrad applies weight decay to the gradient before accumulating the squared gradients: grad = grad.add(param, alpha=weight_decay), i.e. grad = grad + param * weight_decay.
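
One consequence worth noting: because the lambda * param term is added before the squaring, the decay itself is also divided by the adaptive denominator, so its effective strength varies per parameter, unlike a decoupled decay. A sketch (toy values) replaying the first AdaGrad step by hand:

import torch

torch.manual_seed(0)
param = torch.randn(4, requires_grad=True)
lr, wd, eps = 0.1, 0.01, 1e-10

expected = param.detach().clone()
opt = torch.optim.Adagrad([param], lr=lr, weight_decay=wd, eps=eps)

(param ** 2).sum().backward()
grad = param.grad.detach().clone()
opt.step()

# First step: state_sum starts at 0 and clr == lr because step == 1
g = grad + wd * expected   # grad = grad + param * weight_decay
state_sum = g * g          # state_sum.addcmul_(grad, grad, value=1)
expected -= lr * g / (state_sum.sqrt() + eps)

print(torch.allclose(param, expected))  # True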

3. Weight Decay in RMSProp

def rmsprop(params: List[Tensor],
            grads: List[Tensor],
            square_avgs: List[Tensor],
            grad_avgs: List[Tensor],
            momentum_buffer_list: List[Tensor],
            *,
            lr: float,
            alpha: float,
            eps: float,
            weight_decay: float,
            momentum: float,
            centered: bool):

    for i, param in enumerate(params):
        grad = grads[i]
        square_avg = square_avgs[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

        if centered:
            grad_avg = grad_avgs[i]
            grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(eps)
        else:
            avg = square_avg.sqrt().add_(eps)

        if momentum > 0:
            buf = momentum_buffer_list[i]
            buf.mul_(momentum).addcdiv_(grad, avg)
            param.add_(buf, alpha=-lr)
        else:
            param.addcdiv_(grad, avg, value=-lr)

RMSProp does the same: before computing the moving average of squared gradients, it applies weight decay to the gradient, i.e. grad = grad.add(param, alpha=weight_decay).
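
The same one-step replay works here (sketch, toy values; with momentum = 0 and centered = False only the square_avg branch runs):

import torch

torch.manual_seed(0)
param = torch.randn(4, requires_grad=True)
lr, alpha, eps, wd = 0.01, 0.99, 1e-8, 0.01

expected = param.detach().clone()
opt = torch.optim.RMSprop([param], lr=lr, alpha=alpha, eps=eps, weight_decay=wd)

(param ** 2).sum().backward()
grad = param.grad.detach().clone()
opt.step()

g = grad + wd * expected           # decay first
square_avg = (1 - alpha) * g * g   # moving average, starting from 0
expected -= lr * g / (square_avg.sqrt() + eps)

print(torch.allclose(param, expected))  # True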

4. Weight Decay in Adam

def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)

Adam follows the same pattern as the optimizers above, first applying weight decay to the gradient: grad = grad.add(param, alpha=weight_decay).
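
This, too, can be replayed by hand (sketch, toy values). Note that the lambda * param term flows through both moment estimates and the bias corrections, so Adam's decay is rescaled by the adaptive denominator rather than applied at full strength:

import torch

torch.manual_seed(0)
param = torch.randn(4, requires_grad=True)
lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 0.01

expected = param.detach().clone()
opt = torch.optim.Adam([param], lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=wd)

(param ** 2).sum().backward()
grad = param.grad.detach().clone()
opt.step()

g = grad + wd * expected          # grad = grad + param * weight_decay
exp_avg = (1 - beta1) * g         # first step: moments start at 0
exp_avg_sq = (1 - beta2) * g * g
denom = (exp_avg_sq / (1 - beta2)).sqrt() + eps  # bias-corrected denominator
expected -= (lr / (1 - beta1)) * exp_avg / denom

print(torch.allclose(param, expected))  # True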

5. Weight Decay in AdamW

def adamw(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          max_exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          *,
          amsgrad: bool,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)

AdamW is a little different from Adam. If you look closely, the AdamW and Adam implementations are almost identical; only the weight decay part of the code differs:

# AdamW
# param = param - lr * param * weight_decay
param.mul_(1 - lr * weight_decay)

# Adam
if weight_decay != 0:
   grad = grad.add(param, alpha=weight_decay)

Adam applies the weight decay computation to the gradient right at the start, whereas AdamW applies it directly to the parameter and then proceeds with the rest of the computation. Judging from the code, AdamW and Adam are equivalent when weight_decay = 0; only when weight_decay is non-zero do the two computations diverge. To confirm this, I ran a small experiment: changing AdamW's weight_decay from its default (1e-2 in PyTorch) to 0 produces results identical to Adam's.

(Screenshot omitted: the upper output is AdamW with weight_decay = 0, the lower is the Adam result; they match exactly.)
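
This check is easy to reproduce; a sketch along these lines (hypothetical toy parameters) prints True:

import torch

torch.manual_seed(0)
p1 = torch.randn(4, requires_grad=True)
p2 = p1.detach().clone().requires_grad_(True)

opt1 = torch.optim.Adam([p1], lr=1e-3, weight_decay=0.0)
opt2 = torch.optim.AdamW([p2], lr=1e-3, weight_decay=0.0)

for _ in range(5):
    opt1.zero_grad()
    opt2.zero_grad()
    (p1 ** 2).sum().backward()
    (p2 ** 2).sum().backward()
    opt1.step()
    opt2.step()

print(torch.allclose(p1, p2))  # True: identical once weight_decay == 0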

Some references point out that AdamW implements weight decay correctly by decoupling it from the gradient-based update (the "decoupled weight decay" of Loshchilov & Hutter's AdamW paper), and that it is therefore an improvement over Adam.
