我们先回顾下梯度下降法参数更新的公式:
从公式,可以很清楚的看到,参数能否学习就看学习率 LR 和梯度 G 了。如果某一点的梯度是 0 的话,那么参数就无法更新。什么时候会出现梯度为 0 的情况?比如局部极小值,另外就是鞍点(该点的梯度为 0,但不是极小值),当然如果学习率为 0 或者太小的话,参数也无法得到学习。我们此处暂且不考虑对 LR 的优化,将重心放在对梯度 G 的优化上。
我们要介绍的 momentum 动量法就是对梯度下降公式中的 G 进行优化的方法。它是如何进行优化的呢?我们还是先看下 momentum 的数学公式,即:梯度下降公式、参数的更新公式变成下面的样子了。
- μ 表示动量的参数,该值通常设置为 0.9;
- v 表示 velocity,可以理解为沿着某一方向的速度,动量;
- lr 表示学习率;
- g 表示 gradient 梯度;
- p 表示要更新的参数。
使用 momentum 更新参数时,可以分为两个步骤,首先计算某一个方向的速度,然后更新参数即可。它是如何计算的呢?
- 学习率 LR = 0.1;
- 梯度值固定为 2;
- 参数初始值为 1;
- 动量参数值为 0.9.
第一次迭代时:\(v_t\) = 0 \(v_{t+1}\) = 2,更新参数:1 – 0.1 * 2 = 0.8。
第二次迭代时:\(v_{t}\) = 2 \(v_{t+1}\) = 2 * 0.9 + 2 = 3.8,更新参数:0.8 – 0.1 * 3.8 = 0.42。
第三次迭代时:\(v_{t}\) = 3.8 \(v_{t+1}\) = 3.8 * 0.9 + 2 = 5.42,更新参数:0.42 – 0.1 * 5.42 = -0.122。
那么,假设碰到梯度为 0 的情况,比如参数还会更新吗?
我们先把 momentum 的公式表示如下:
假设此时进行第四次参数更新,g 的值为 0 表示当前点的梯度为 0,那么表示当前可能是鞍点或者局部极小值。通过公式,我们发现即使梯度为 0,第三项 \(lr * μ * v_{t}\) 由于积累了一些能量,所以也能够使得参数进行更新。就好像,我们从山上往下跑,虽然碰到了平地,但是仍然会有一些势能使得我们向前移动。
第四次迭代时,注意此时的梯度值为 0:\(v_{t}\) = 5.42,\(v_{t+1}\) =5.42 * 0.9 + 0 = 4.878。更新参数:-0.122 – 0.1 * 4.878 = -0.6098。
接下来,我们使用 Pytorch 来验证下这部分的计算,代码如下:
import torch import torch.optim as optim if __name__ == '__main__': # 构造初始参数 param = torch.tensor([1], dtype=torch.float32) # 设置梯度值 param.grad = torch.tensor([2], dtype=torch.float32) # 使用 SGD 优化器 # 动量参数 momentum 为 0.9 # 学习率为 0.1 optimizer = optim.SGD([param], lr=0.1, momentum=0.9) # 1. 初始化动量 print(param) # 2. 第一次更新后的值 optimizer.step() print(optimizer.state) # 3. 第二次更新后的值 optimizer.step() print(optimizer.state) # 4. 第三次更新后的值 optimizer.step() print(optimizer.state) # 5. 此时碰到鞍点,梯度为 0 param.grad = torch.tensor([0], dtype=torch.float32) optimizer.step() print(optimizer.state)
程序输出结果:
tensor([1.]) defaultdict(<class 'dict'>, {tensor([0.8000]): {'momentum_buffer': tensor([2.])}}) defaultdict(<class 'dict'>, {tensor([0.4200]): {'momentum_buffer': tensor([3.8000])}}) defaultdict(<class 'dict'>, {tensor([-0.1220]): {'momentum_buffer': tensor([5.4200])}}) defaultdict(<class 'dict'>, {tensor([-0.6098]): {'momentum_buffer': tensor([4.8780])}})
PyTorch 中 SGD 的计算过程: