When you first see the term weight decay, it literally reads as "decaying the weight parameters", so it is easy to assume it is equivalent to L2 regularization, since L2 regularization also drives the weights toward smaller values. In practice, however, the two differ both in concept and in where they are implemented.
L2 regularization is usually added as an extra term to the loss function: it takes part in the loss computation and influences the gradients through backpropagation. Weight decay, by contrast, shrinks the weights directly in the optimizer's parameter-update step; it is an operation internal to the optimizer.
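To make the distinction concrete, here is a minimal sketch of the two approaches on a toy model (the model, data, and the coefficients l2_lambda / weight_decay are placeholders chosen only for illustration):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                     # toy model
criterion = nn.MSELoss()
x, y = torch.randn(32, 10), torch.randn(32, 1)

# Approach 1: L2 regularization, added as a penalty term to the loss,
# so it reaches the weights through backpropagation.
l2_lambda = 1e-4
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = criterion(model(x), y)
loss = loss + l2_lambda * sum(p.pow(2).sum() for p in model.parameters())
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Approach 2: weight decay, handled entirely inside the optimizer's update step;
# the loss itself contains no penalty term.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
loss = criterion(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()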


Both L2 regularization and weight decay make the weights smaller, which lowers model complexity and helps against overfitting. Suppose the model already fits most of the normal samples well, but the data also contains a few outliers that sit far from the overall distribution. To minimize the loss, the model is pushed to fit these outliers as well: they produce large errors, hence large gradients, and therefore large parameter updates.
With weight decay, the size of the parameter updates is reined in. An update that would have been +8 might shrink to something like +3. The model then no longer makes drastic parameter changes just to accommodate a few outliers, so it fits less noise and overfits less.
Weight decay is computed inside the optimizer, whereas L2 regularization is added to the loss function. So how do the common optimizers actually apply the weight decay term?
1. Weight Decay in SGD with Momentum
Let's look at how weight decay actually enters the computation, using PyTorch's optimizer implementations. First, the source code of SGD with Momentum:
def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool):

    for i, param in enumerate(params):

        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)
In the code, d_p is the parameter's gradient. SGD with Momentum applies weight decay to the gradient before the momentum computation, i.e. d_p = d_p + param * weight_decay, and only then accumulates the result into the momentum buffer.
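Since the decay term is simply added to the gradient, passing weight_decay to the optimizer should produce exactly the same update as adding weight_decay * param to each gradient by hand. Here is a minimal sketch of that check (the model, data, and hyperparameter values are placeholders):

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model_a = nn.Linear(4, 1)
model_b = copy.deepcopy(model_a)
x, y = torch.randn(8, 4), torch.randn(8, 1)
wd, lr = 0.1, 0.01

# A: let the optimizer apply weight decay internally.
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
opt_a.zero_grad()
F.mse_loss(model_a(x), y).backward()
opt_a.step()

# B: plain SGD, but add wd * param to each gradient manually
# (what d_p.add(param, alpha=weight_decay) does in the source above).
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr, momentum=0.9)
opt_b.zero_grad()
F.mse_loss(model_b(x), y).backward()
with torch.no_grad():
    for p in model_b.parameters():
        p.grad.add_(p, alpha=wd)
opt_b.step()

for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    print(torch.allclose(pa, pb))   # True: the two updates coincide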
2. Weight Decay in AdaGrad
def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            *,
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):

    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
        else:
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
AdaGrad behaves just like SGD with Momentum: before accumulating the squared gradients, it applies weight decay to the gradient, i.e. grad = grad.add(param, alpha=weight_decay), which is grad = grad + param * weight_decay.
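Also note the first branch in the code above: weight decay is refused when the gradients are sparse. A small sketch of how that surfaces in practice (the embedding sizes and hyperparameters are arbitrary placeholders):

import torch
import torch.nn as nn

emb = nn.Embedding(100, 8, sparse=True)      # sparse=True makes the gradient a sparse tensor
opt = torch.optim.Adagrad(emb.parameters(), lr=0.1, weight_decay=1e-4)

emb(torch.tensor([1, 2, 3])).sum().backward()
try:
    opt.step()
except RuntimeError as e:
    print(e)   # "weight_decay option is not compatible with sparse gradients"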
3. Weight Decay in RMSProp
def rmsprop(params: List[Tensor],
            grads: List[Tensor],
            square_avgs: List[Tensor],
            grad_avgs: List[Tensor],
            momentum_buffer_list: List[Tensor],
            *,
            lr: float,
            alpha: float,
            eps: float,
            weight_decay: float,
            momentum: float,
            centered: bool):

    for i, param in enumerate(params):
        grad = grads[i]
        square_avg = square_avgs[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

        if centered:
            grad_avg = grad_avgs[i]
            grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(eps)
        else:
            avg = square_avg.sqrt().add_(eps)

        if momentum > 0:
            buf = momentum_buffer_list[i]
            buf.mul_(momentum).addcdiv_(grad, avg)
            param.add_(buf, alpha=-lr)
        else:
            param.addcdiv_(grad, avg, value=-lr)
RMSProp does the same: before updating the running average of squared gradients, it applies weight decay to the gradient, i.e. grad = grad.add(param, alpha=weight_decay).
4. Weight Decay in Adam
def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)
Adam follows the same pattern as the optimizers above: the gradient gets weight decay applied first, grad = grad.add(param, alpha=weight_decay).
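One consequence of folding the decay into the gradient is worth spelling out, because it is exactly what AdamW changes below: the decayed gradient feeds both moment estimates, so the penalty ends up divided by the same adaptive denominator as the gradient. A rough sketch of the effective per-parameter update, written out as comments:

# With g = grad + weight_decay * param, Adam computes:
#   exp_avg    = beta1 * exp_avg    + (1 - beta1) * g
#   exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g**2
#   denom      = sqrt(exp_avg_sq) / sqrt(bias_correction2) + eps
#   param      = param - (lr / bias_correction1) * exp_avg / denom
# Because the weight_decay * param term passes through exp_avg / denom,
# parameters with a large gradient history are effectively decayed less
# than the plain L2 penalty would suggest.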
5. Weight Decay in AdamW
def adamw(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          max_exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          *,
          amsgrad: bool,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)
AdamW is slightly different from Adam. If you look closely, the AdamW implementation is almost identical to Adam's; the only difference is the weight decay part:
# AdamW
# param = param - lr * param * weight_decay
param.mul_(1 - lr * weight_decay)

# Adam
if weight_decay != 0:
    grad = grad.add(param, alpha=weight_decay)
Adam folds the weight decay into the gradient at the very start, whereas AdamW applies the decay directly to the parameters and then proceeds with the rest of the computation. If weight_decay = 0, the Adam and AdamW code paths look equivalent; only with a non-zero weight_decay do the two differ. To check this, I ran a small experiment: changing AdamW's weight_decay from its default value (1e-2 in PyTorch) to 0 produces exactly the same results as Adam.
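A minimal version of that experiment might look like this (the toy model, data, and step count are placeholders; the point is only that the two trajectories coincide when weight_decay is 0):

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model_adam = nn.Linear(4, 2)
model_adamw = copy.deepcopy(model_adam)
x, y = torch.randn(16, 4), torch.randn(16, 2)

opt_adam = torch.optim.Adam(model_adam.parameters(), lr=1e-3, weight_decay=0.0)
opt_adamw = torch.optim.AdamW(model_adamw.parameters(), lr=1e-3, weight_decay=0.0)

for _ in range(10):
    for model, opt in ((model_adam, opt_adam), (model_adamw, opt_adamw)):
        opt.zero_grad()
        F.mse_loss(model(x), y).backward()
        opt.step()

for p_adam, p_adamw in zip(model_adam.parameters(), model_adamw.parameters()):
    print(torch.allclose(p_adam, p_adamw))   # True: with weight_decay=0 the two optimizers coincide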

Some references point out that AdamW is the variant that implements weight decay "correctly", i.e. decoupled from the adaptive gradient update, and that it is in this sense a refinement of Adam; this is the idea behind the paper that introduced AdamW, "Decoupled Weight Decay Regularization" by Loshchilov and Hutter.


