1. 多分类损失计算
我们在计算多分类损失时,使用的是多分类交叉熵损失。其简要的计算过程如下:假设:真实的标签为 [0, 0, 1, 0, 0],预测的分数为 [0.15, -0.34, 0.12, 0.67, 0.55],计算过程如下:
- 对每个 Logit 计算指数;
- 计算每个值的概率值;
- 计算正确标签对应的概率值的负对数值;
import torch.nn as nn
import torch
def test01():
labels = torch.tensor([2])
logits = torch.tensor([[0.15, -0.34, 0.12, 0.67, 0.55]])
loss = nn.CrossEntropyLoss()(logits, labels)
print(loss)
# 慢动作
# 1. 对每个分数计算指数
temp = torch.exp(logits)
# 2. 计算每个值的概率值
probas = temp / torch.sum(temp, dim=-1)
# 3. 计算正确标签对应的概率值的负对数值
loss = -torch.log(probas[0][2])
print(loss)
if __name__ == '__main__':
test01()
程序执行结果:
tensor(1.7804) tensor(1.7804)
在上面代码中,多分类场景下,每个样本的正确标签标识可以用 one-hot 表示为:[0, 0, 1, 0, 0]。
2. 多标签损失计算
多标签损失计算时,我们就不能使用 CrossEntropyLoss 损失函数,这是因为在交叉熵损失函数中,标签之间都是互斥的,即: 每个样本只有一个标签。
在多标签中,我们倾向于将每个样本的多个标签理解为相互独立的,即: 每个样本可以多个标签。此时,计算过程如下:
假设:某个样本有两个标签 1、3,我们就可以用 Multi Hot 来表示该样本的多个标签:[0, 1, 0, 1, 0],标签位置为1,其余为 0。每个位置预测的分数为:[0.15, -0.34, 0.12, 0.67, 0.55]
1. 先对 1 位置对应的分数使用 sigmoid 函数,计算其为 1 的概率;
2. 对 1 类别分数计算 -torch.log(logit_sigmoid), 对 0 类别分数计算 -torch.log(1-logit_sigmoid)
3. 将负对数值相加再计算均值即可得到某个样本多标签的损失值。
import torch.nn as nn
import torch
def test02():
# 模型预测分数
logits = torch.tensor([[0.15, -0.34, 0.12, 0.67, 0.55]])
# 真实多标签
# 标签要用小数
labels = torch.tensor([[0, 1, 0, 1, 0]], dtype=torch.float32)
criterion = nn.BCELoss()
loss = criterion(torch.sigmoid(logits), labels)
print(loss)
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, labels)
print(loss)
criterion = nn.MultiLabelSoftMarginLoss()
loss = criterion(logits, labels)
print(loss)
# 慢动作
# 1. 先对每个分数计算 sigmoid 值
temp = torch.sigmoid(logits)
# 2. 对于正确标签计算负对数值 -log(temp),对于不正确标签计算 -log(1-temp)
loss = 0.0
for index, value in zip([0, 1, 0, 1, 0], temp[0]):
if index == 0:
loss += -torch.log(1-value)
if index == 1:
loss += -torch.log(value)
# 3. 对所有标签取均值
loss = loss / 5
print(loss)
if __name__ == '__main__':
test02()
程序执行结果:
tensor(0.7644) tensor(0.7644) tensor(0.7644) tensor(0.7644)
从计算过程来看,对于多分类,其标签是 one hot 表示形式 [0, 1, 0, 0, 0],而多标签分类中,每个样本的标签是 multi hot 表示形式 [0, 1, 0, 1, 0]。注意在使用 BCELoss、BCEWithLogitsLoss、MultiLabelSoftMarginLoss 计算多标签损失时,标签需要表示成 multi hot 形式,另外该 multi hot 中的每个元素必须是 torch.float 类型。

冀公网安备13050302001966号