1. 多分类损失计算
我们在计算多分类损失时,使用的是多分类交叉熵损失。其简要的计算过程如下:假设:真实的标签为 [0, 0, 1, 0, 0],预测的分数为 [0.15, -0.34, 0.12, 0.67, 0.55],计算过程如下:
- 对每个 Logit 计算指数;
- 计算每个值的概率值;
- 计算正确标签对应的概率值的负对数值;
import torch.nn as nn import torch def test01(): labels = torch.tensor([2]) logits = torch.tensor([[0.15, -0.34, 0.12, 0.67, 0.55]]) loss = nn.CrossEntropyLoss()(logits, labels) print(loss) # 慢动作 # 1. 对每个分数计算指数 temp = torch.exp(logits) # 2. 计算每个值的概率值 probas = temp / torch.sum(temp, dim=-1) # 3. 计算正确标签对应的概率值的负对数值 loss = -torch.log(probas[0][2]) print(loss) if __name__ == '__main__': test01()
程序执行结果:
tensor(1.7804) tensor(1.7804)
在上面代码中,多分类场景下,每个样本的正确标签标识可以用 one-hot 表示为:[0, 0, 1, 0, 0]。
2. 多标签损失计算
多标签损失计算时,我们就不能使用 CrossEntropyLoss 损失函数,这是因为在交叉熵损失函数中,标签之间都是互斥的,即: 每个样本只有一个标签。
在多标签中,我们倾向于将每个样本的多个标签理解为相互独立的,即: 每个样本可以多个标签。此时,计算过程如下:
假设:某个样本有两个标签 1、3,我们就可以用 Multi Hot 来表示该样本的多个标签:[0, 1, 0, 1, 0],标签位置为1,其余为 0。每个位置预测的分数为:[0.15, -0.34, 0.12, 0.67, 0.55]
1. 先对 1 位置对应的分数使用 sigmoid 函数,计算其为 1 的概率;
2. 对 1 类别分数计算 -torch.log(logit_sigmoid), 对 0 类别分数计算 -torch.log(1-logit_sigmoid)
3. 将负对数值相加再计算均值即可得到某个样本多标签的损失值。
import torch.nn as nn import torch def test02(): # 模型预测分数 logits = torch.tensor([[0.15, -0.34, 0.12, 0.67, 0.55]]) # 真实多标签 # 标签要用小数 labels = torch.tensor([[0, 1, 0, 1, 0]], dtype=torch.float32) criterion = nn.BCELoss() loss = criterion(torch.sigmoid(logits), labels) print(loss) criterion = nn.BCEWithLogitsLoss() loss = criterion(logits, labels) print(loss) criterion = nn.MultiLabelSoftMarginLoss() loss = criterion(logits, labels) print(loss) # 慢动作 # 1. 先对每个分数计算 sigmoid 值 temp = torch.sigmoid(logits) # 2. 对于正确标签计算负对数值 -log(temp),对于不正确标签计算 -log(1-temp) loss = 0.0 for index, value in zip([0, 1, 0, 1, 0], temp[0]): if index == 0: loss += -torch.log(1-value) if index == 1: loss += -torch.log(value) # 3. 对所有标签取均值 loss = loss / 5 print(loss) if __name__ == '__main__': test02()
程序执行结果:
tensor(0.7644) tensor(0.7644) tensor(0.7644) tensor(0.7644)
从计算过程来看,对于多分类,其标签是 one hot 表示形式 [0, 1, 0, 0, 0],而多标签分类中,每个样本的标签是 multi hot 表示形式 [0, 1, 0, 1, 0]。注意在使用 BCELoss、BCEWithLogitsLoss、MultiLabelSoftMarginLoss 计算多标签损失时,标签需要表示成 multi hot 形式,另外该 multi hot 中的每个元素必须是 torch.float 类型。