SGD 优化器原理

我们先回顾下梯度下降法参数更新的公式:

从公式,可以很清楚的看到,参数能否学习就看学习率 LR 和梯度 G 了。如果某一点的梯度是 0 的话,那么参数就无法更新。什么时候会出现梯度为 0 的情况?比如局部极小值,另外就是鞍点(该点的梯度为 0,但不是极小值),当然如果学习率为 0 或者太小的话,参数也无法得到学习。我们此处暂且不考虑对 LR 的优化,将重心放在对梯度 G 的优化上。

我们要介绍的 momentum 动量法就是对梯度下降公式中的 G 进行优化的方法。它是如何进行优化的呢?我们还是先看下 momentum 的数学公式,即:梯度下降公式、参数的更新公式变成下面的样子了。

  1. μ 表示动量的参数,该值通常设置为 0.9;
  2. v 表示 velocity,可以理解为沿着某一方向的速度,动量;
  3. lr 表示学习率;
  4. g 表示 gradient 梯度;
  5. p 表示要更新的参数。

使用 momentum 更新参数时,可以分为两个步骤,首先计算某一个方向的速度,然后更新参数即可。它是如何计算的呢?

  1. 学习率 LR = 0.1;
  2. 梯度值固定为 2;
  3. 参数初始值为 1;
  4. 动量参数值为 0.9.

第一次迭代时:\(v_t\) = 0 \(v_{t+1}\) = 2,更新参数:1 – 0.1 * 2 = 0.8。
第二次迭代时:\(v_{t}\) = 2 \(v_{t+1}\) = 2 * 0.9 + 2 = 3.8,更新参数:0.8 – 0.1 * 3.8 = 0.42。
第三次迭代时:\(v_{t}\) = 3.8 \(v_{t+1}\) = 3.8 * 0.9 + 2 = 5.42,更新参数:0.42 – 0.1 * 5.42 = -0.122。

那么,假设碰到梯度为 0 的情况,比如参数还会更新吗?

我们先把 momentum 的公式表示如下:

假设此时进行第四次参数更新,g 的值为 0 表示当前点的梯度为 0,那么表示当前可能是鞍点或者局部极小值。通过公式,我们发现即使梯度为 0,第三项 \(lr * μ * v_{t}\) 由于积累了一些能量,所以也能够使得参数进行更新。就好像,我们从山上往下跑,虽然碰到了平地,但是仍然会有一些势能使得我们向前移动。

第四次迭代时,注意此时的梯度值为 0:\(v_{t}\) = 5.42,\(v_{t+1}\) =5.42 * 0.9 + 0 = 4.878。更新参数:-0.122 – 0.1 * 4.878 = -0.6098。

接下来,我们使用 Pytorch 来验证下这部分的计算,代码如下:

import torch
import torch.optim as optim


if __name__ == '__main__':

    # 构造初始参数
    param = torch.tensor([1], dtype=torch.float32)
    # 设置梯度值
    param.grad = torch.tensor([2], dtype=torch.float32)
    # 使用 SGD 优化器
    # 动量参数 momentum 为 0.9
    # 学习率为 0.1
    optimizer = optim.SGD([param], lr=0.1, momentum=0.9)

    # 1. 初始化动量
    print(param)

    # 2. 第一次更新后的值
    optimizer.step()
    print(optimizer.state)

    # 3. 第二次更新后的值
    optimizer.step()
    print(optimizer.state)

    # 4. 第三次更新后的值
    optimizer.step()
    print(optimizer.state)

    # 5. 此时碰到鞍点,梯度为 0
    param.grad = torch.tensor([0], dtype=torch.float32)
    optimizer.step()
    print(optimizer.state)

程序输出结果:

tensor([1.])
defaultdict(<class 'dict'>, {tensor([0.8000]): {'momentum_buffer': tensor([2.])}})
defaultdict(<class 'dict'>, {tensor([0.4200]): {'momentum_buffer': tensor([3.8000])}})
defaultdict(<class 'dict'>, {tensor([-0.1220]): {'momentum_buffer': tensor([5.4200])}})
defaultdict(<class 'dict'>, {tensor([-0.6098]): {'momentum_buffer': tensor([4.8780])}})

PyTorch 中 SGD 的计算过程:

未经允许不得转载:一亩三分地 » SGD 优化器原理