BP (Back Propagation)算法也叫做误差反向传播算法,它用于求解模型的参数梯度,从而使用梯度下降法来更新网络参数。它的基本工作流程如下:
- 通过正向传播得到误差,所谓正向传播指的是数据从输入到输出层,经过层层计算得到预测值,并利用损失函数得到预测值和真实值之前的误差。
- 通过反向传播把误差传递给模型的参数,从而对网络参数进行适当的调整,缩小预测值和真实值之间的误差。
反向传播算法是利用链式法则进行梯度求解及权重更新的。对于复杂的复合函数,我们将其拆分为一系列的加减乘除或指数,对数,三角函数等初等函数,通过链式法则完成复合函数的求导。
我们通过一个例子来简单理解下 BP 算法进行网络参数更新的过程:
为了能够把计算过程描述的更详细一些,上图中一个矩形代表一个神经元,每个神经元中分别是值和激活值的计算结果和其对应的公式,最终计算出真实值和预测值之间的误差 0.2984. 其中
- 由下向上看,最下层绿色的两个圆代表两个输入值
- 右侧的8个数字,最下面4个表示 w1、w2、w3、w4 的参数初始值,最上面的4个数字表示 w5、w6、w7、w8 的参数初始值
- b1 值为 0.35,b2 值为 0.60
- 预测结果分别为: 0.7514、0.7729
我们首先计算 w5 和 w7 两个权重的梯度,然后使用梯度下降更新这两个参数。
计算出了梯度值,接下来使用使用梯度下降公式来更新模型参数,假设:学习率为 0.5,则:
接下来,我们计算 w1 的梯度,以及更新该参数:
接下来更新该参数:
其他的网络参数更新过程和上面的过程是一样的。下面我们使用代码构建上面的网络,并进行一次正向传播和反向传播。
import torch import torch.nn as nn import torch.optim as optim class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.linear1 = nn.Linear(2, 2) self.linear2 = nn.Linear(2, 2) # 网络参数初始化 self.linear1.weight.data = torch.tensor([[0.15, 0.20], [0.25, 0.30]]) self.linear2.weight.data = torch.tensor([[0.40, 0.45], [0.50, 0.55]]) self.linear1.bias.data = torch.tensor([0.35, 0.35]) self.linear2.bias.data = torch.tensor([0.60, 0.60]) def forward(self, x): x = self.linear1(x) x = torch.sigmoid(x) x = self.linear2(x) x = torch.sigmoid(x) return x if __name__ == '__main__': inputs = torch.tensor([[0.05, 0.10]]) target = torch.tensor([[0.01, 0.99]]) # 获得网络输出值 net = Net() output = net(inputs) # print(output) # tensor([[0.7514, 0.7729]], grad_fn=<SigmoidBackward>) # 计算误差 loss = torch.sum((output - target) ** 2) / 2 # print(loss) # tensor(0.2984, grad_fn=<DivBackward0>) # 优化方法 optimizer = optim.SGD(net.parameters(), lr=0.5) # 梯度清零 optimizer.zero_grad() # 反向传播 loss.backward() # 打印 w5、w7、w1 的梯度值 print(net.linear1.weight.grad.data) # tensor([[0.0004, 0.0009], # [0.0005, 0.0010]]) print(net.linear2.weight.grad.data) # tensor([[ 0.0822, 0.0827], # [-0.0226, -0.0227]]) # 打印网络参数 optimizer.step() print(net.state_dict()) # OrderedDict([('linear1.weight', tensor([[0.1498, 0.1996], [0.2498, 0.2995]])), # ('linear1.bias', tensor([0.3456, 0.3450])), # ('linear2.weight', tensor([[0.3589, 0.4087], [0.5113, 0.5614]])), # ('linear2.bias', tensor([0.5308, 0.6190]))])
从代码可以看出,我们手算结果和程序的运行结果是一致的。