When you first see the term weight decay, the literal reading is "the weight parameters decay," so it is easy to assume it is equivalent to L2 regularization, since L2 regularization also shrinks the weights. The two concepts are actually different. L2 regularization is typically added to the loss function as a penalty term: it becomes part of the computed loss and therefore participates in backpropagation. Weight decay, in contrast, is applied directly in the parameter-update step, i.e., it is a computation performed inside the optimizer.
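To make the difference concrete, here is a minimal sketch (the toy model, data shapes, and the 1e-2 coefficient are arbitrary choices for illustration) contrasting an L2 penalty added to the loss with weight decay configured on the optimizer. Note that for plain SGD the two only coincide up to a constant factor, since the gradient of the penalty lambda * ||w||^2 is 2 * lambda * w.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)                     # toy model, shapes chosen arbitrarily
loss_fn = nn.MSELoss()
x, y = torch.randn(32, 10), torch.randn(32, 1)

# (a) L2 regularization: the penalty is part of the loss and flows through backward()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
l2_lambda = 1e-2
penalty = sum(p.pow(2).sum() for p in model.parameters())
loss = loss_fn(model(x), y) + l2_lambda * penalty
opt.zero_grad()
loss.backward()
opt.step()

# (b) weight decay: the loss stays plain; the decay is applied inside optimizer.step()
opt = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-2)
loss = loss_fn(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()
```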
Whether we use L2 regularization or weight decay, the purpose is the same: to reduce overfitting. How can we intuitively understand that shrinking the weights reduces overfitting?

For example, suppose our model already fits most of the samples well, but a few outliers lie far away from the normal samples. The model is not that smart: it simply fits all samples by minimizing the loss we give it. It finds that those distant outliers make the loss large, so it is forced to fit them as well. The gradients produced by these outliers are large, and the parameters therefore get adjusted by large amounts.

Now add weight decay: when an outlier would push a parameter by, say, +8, the decay shrinks that adjustment to something like +3. In this way the model reduces overfitting through weight decay.

Now, back to the main topic.

The weight decay computation happens inside the optimizer, while L2 regularization is added to the loss function. So how do common optimizers apply the weight decay term?
1. Weight Decay in SGD with Momentum
Let's look at PyTorch's optimizer implementations to see how weight decay actually enters the computation. First, the source of SGD with momentum:
```python
def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool):

    for i, param in enumerate(params):
        d_p = d_p_list[i]

        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)
```
In the code, d_p is the parameter's gradient. SGD with momentum applies weight decay to the gradient before computing the momentum term, i.e. d_p = d_p + param * weight_decay, and only then accumulates the gradient into the momentum buffer.
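As a quick sanity check (a sketch with an arbitrary toy tensor and hyperparameters, not part of the PyTorch source), the first update of torch.optim.SGD with momentum and weight_decay can be reproduced by hand:

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
lr, wd, momentum = 0.1, 0.01, 0.9

opt = torch.optim.SGD([w], lr=lr, momentum=momentum, weight_decay=wd)
(w ** 2).sum().backward()
w0, g = w.detach().clone(), w.grad.detach().clone()
opt.step()

# Manual first step: fold the decay into the gradient, initialize the momentum
# buffer with d_p (step 1), then apply the learning-rate update
d_p = g + wd * w0                 # d_p = d_p + weight_decay * param
buf = d_p.clone()                 # buffer is just d_p on the first step
w_manual = w0 - lr * buf

print(torch.allclose(w.detach(), w_manual))   # expected: True
```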
2. Weight Decay in AdaGrad
```python
def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            *,
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):

    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
        else:
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
```
AdaGrad does the same as SGD with momentum: before accumulating the squared gradients, it applies weight decay to the gradient, i.e. grad = grad.add(param, alpha=weight_decay), which is grad = grad + param * weight_decay.
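The same kind of hand computation works for AdaGrad's dense branch (again a toy sketch with arbitrary values; on the first step the accumulator starts at 0, and with the default lr_decay = 0 we have clr = lr):

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
lr, wd, eps = 0.1, 0.01, 1e-10    # eps matches the Adagrad default

opt = torch.optim.Adagrad([w], lr=lr, weight_decay=wd, eps=eps)
(w ** 2).sum().backward()
w0, g = w.detach().clone(), w.grad.detach().clone()
opt.step()

# Manual first step: decay first, then accumulate the *decayed* squared gradient
g = g + wd * w0                               # grad = grad + weight_decay * param
state_sum = g * g                             # accumulator starts at 0
w_manual = w0 - lr * g / (state_sum.sqrt() + eps)

print(torch.allclose(w.detach(), w_manual))   # expected: True
```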
3. Weight Decay in RMSProp
```python
def rmsprop(params: List[Tensor],
            grads: List[Tensor],
            square_avgs: List[Tensor],
            grad_avgs: List[Tensor],
            momentum_buffer_list: List[Tensor],
            *,
            lr: float,
            alpha: float,
            eps: float,
            weight_decay: float,
            momentum: float,
            centered: bool):

    for i, param in enumerate(params):
        grad = grads[i]
        square_avg = square_avgs[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

        if centered:
            grad_avg = grad_avgs[i]
            grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(eps)
        else:
            avg = square_avg.sqrt().add_(eps)

        if momentum > 0:
            buf = momentum_buffer_list[i]
            buf.mul_(momentum).addcdiv_(grad, avg)
            param.add_(buf, alpha=-lr)
        else:
            param.addcdiv_(grad, avg, value=-lr)
```
RMSProp likewise applies weight decay to the gradient before updating the moving average of squared gradients, i.e. grad = grad.add(param, alpha=weight_decay).
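One consequence is that the decayed gradient also enters square_avg, so the decay affects the adaptive denominator as well. A small sketch of the first RMSProp step (toy values; momentum = 0 and centered = False, the defaults):

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
lr, wd, alpha, eps = 0.01, 0.01, 0.99, 1e-8   # alpha and eps are the RMSprop defaults

opt = torch.optim.RMSprop([w], lr=lr, alpha=alpha, eps=eps, weight_decay=wd)
(w ** 2).sum().backward()
w0, g = w.detach().clone(), w.grad.detach().clone()
opt.step()

# Manual first step: the decayed gradient is used both in square_avg and in the update
g = g + wd * w0                               # grad = grad + weight_decay * param
square_avg = (1 - alpha) * g * g              # square_avg starts at 0
w_manual = w0 - lr * g / (square_avg.sqrt() + eps)

print(torch.allclose(w.detach(), w_manual))   # expected: True
```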
4. Weight Decay in Adam
```python
def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)
```
Adam follows the same pattern as the optimizers above: it first applies weight decay to the gradient, grad = grad.add(param, alpha=weight_decay).
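This means the decay term is also divided by Adam's adaptive denominator, which is precisely what AdamW changes in the next section. A toy sketch of Adam's first step with weight decay (default betas and eps assumed, arbitrary tensor and hyperparameters):

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
lr, wd, beta1, beta2, eps = 0.001, 0.01, 0.9, 0.999, 1e-8

opt = torch.optim.Adam([w], lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=wd)
(w ** 2).sum().backward()
w0, g = w.detach().clone(), w.grad.detach().clone()
opt.step()

# Manual first step (step = 1): the decayed gradient feeds both moment estimates
g = g + wd * w0                                   # grad = grad + weight_decay * param
exp_avg = (1 - beta1) * g                         # first moment, starts at 0
exp_avg_sq = (1 - beta2) * g * g                  # second moment, starts at 0
bc1, bc2 = 1 - beta1, 1 - beta2                   # bias corrections at step 1
denom = (exp_avg_sq.sqrt() / bc2 ** 0.5) + eps
w_manual = w0 - (lr / bc1) * exp_avg / denom

print(torch.allclose(w.detach(), w_manual))       # expected: True
```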
5. Weight Decay in AdamW
```python
def adamw(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          max_exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          *,
          amsgrad: bool,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)
```
AdamW is a bit different here. If you look closely, the AdamW and Adam implementations are almost identical; the only difference is the weight decay part:
```python
# AdamW: decay the parameters directly
# param = param * (1 - lr * weight_decay)
param.mul_(1 - lr * weight_decay)

# Adam: fold the decay into the gradient
if weight_decay != 0:
    grad = grad.add(param, alpha=weight_decay)
```
Adam folds weight decay into the gradient at the very start of the update, whereas AdamW applies the decay directly to the parameters and then proceeds with the rest of the computation. If weight_decay = 0, AdamW and Adam look equivalent from the code; only when weight_decay is non-zero do the two update rules actually differ. To confirm this, I ran a small experiment: changing AdamW's weight_decay from its default value (1e-2 in PyTorch) to 0 produces exactly the same results as Adam.
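A sketch of that experiment (the toy network, data, and number of steps are arbitrary choices): two identical copies of a model, one stepped with Adam and one with AdamW, both with weight_decay = 0, end up with identical parameters.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
net_adam = nn.Linear(4, 2)
net_adamw = copy.deepcopy(net_adam)            # identical initial weights

opt_adam = torch.optim.Adam(net_adam.parameters(), lr=1e-2, weight_decay=0.0)
opt_adamw = torch.optim.AdamW(net_adamw.parameters(), lr=1e-2, weight_decay=0.0)

x, y = torch.randn(16, 4), torch.randn(16, 2)
loss_fn = nn.MSELoss()

for _ in range(10):
    for net, opt in ((net_adam, opt_adam), (net_adamw, opt_adamw)):
        opt.zero_grad()
        loss_fn(net(x), y).backward()
        opt.step()

same = all(torch.allclose(p, q) for p, q in zip(net_adam.parameters(), net_adamw.parameters()))
print(same)   # expected: True; with a non-zero weight_decay the parameters diverge
```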
Some references, notably the "Decoupled Weight Decay Regularization" paper that introduced AdamW, argue that AdamW implements weight decay correctly by decoupling it from the gradient-based update, and is therefore an improvement over Adam.