多标签损失计算

1. 多分类损失计算

我们在计算多分类损失时,使用的是多分类交叉熵损失。其简要的计算过程如下:假设:真实的标签为 [0, 0, 1, 0, 0],预测的分数为 [0.15, -0.34, 0.12, 0.67, 0.55],计算过程如下:

  1. 对每个 Logit 计算指数;
  2. 计算每个值的概率值;
  3. 计算正确标签对应的概率值的负对数值;

import torch.nn as nn
import torch


def test01():

    labels = torch.tensor([2])
    logits = torch.tensor([[0.15, -0.34, 0.12, 0.67, 0.55]])

    loss = nn.CrossEntropyLoss()(logits, labels)
    print(loss)

    # 慢动作
    # 1. 对每个分数计算指数
    temp = torch.exp(logits)
    # 2. 计算每个值的概率值
    probas = temp / torch.sum(temp, dim=-1)
    # 3. 计算正确标签对应的概率值的负对数值
    loss = -torch.log(probas[0][2])
    print(loss)

if __name__ == '__main__':
    test01()

程序执行结果:

tensor(1.7804)
tensor(1.7804)

在上面代码中,多分类场景下,每个样本的正确标签标识可以用 one-hot 表示为:[0, 0, 1, 0, 0]。

2. 多标签损失计算

多标签损失计算时,我们就不能使用 CrossEntropyLoss 损失函数,这是因为在交叉熵损失函数中,标签之间都是互斥的,即: 每个样本只有一个标签。
在多标签中,我们倾向于将每个样本的多个标签理解为相互独立的,即: 每个样本可以多个标签。此时,计算过程如下:
假设:某个样本有两个标签 1、3,我们就可以用 Multi Hot 来表示该样本的多个标签:[0, 1, 0, 1, 0],标签位置为1,其余为 0。每个位置预测的分数为:[0.15, -0.34, 0.12, 0.67, 0.55]


1. 先对 1 位置对应的分数使用 sigmoid 函数,计算其为 1 的概率;
2. 对 1 类别分数计算 -torch.log(logit_sigmoid), 对 0 类别分数计算 -torch.log(1-logit_sigmoid)
3. 将负对数值相加再计算均值即可得到某个样本多标签的损失值。

import torch.nn as nn
import torch

def test02():

    # 模型预测分数
    logits = torch.tensor([[0.15, -0.34, 0.12, 0.67, 0.55]])
    # 真实多标签
    # 标签要用小数
    labels = torch.tensor([[0, 1, 0, 1, 0]], dtype=torch.float32)

    criterion = nn.BCELoss()
    loss = criterion(torch.sigmoid(logits), labels)
    print(loss)

    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(logits, labels)
    print(loss)

    criterion = nn.MultiLabelSoftMarginLoss()
    loss = criterion(logits, labels)
    print(loss)

    # 慢动作
    # 1. 先对每个分数计算 sigmoid 值
    temp = torch.sigmoid(logits)
    # 2. 对于正确标签计算负对数值 -log(temp),对于不正确标签计算 -log(1-temp)
    loss = 0.0
    for index, value in zip([0, 1, 0, 1, 0], temp[0]):
        if index == 0:
            loss += -torch.log(1-value)
        if index == 1:
            loss += -torch.log(value)
    # 3. 对所有标签取均值
    loss = loss / 5
    print(loss)


if __name__ == '__main__':
    test02()

程序执行结果:

tensor(0.7644)
tensor(0.7644)
tensor(0.7644)
tensor(0.7644)

从计算过程来看,对于多分类,其标签是 one hot 表示形式 [0, 1, 0, 0, 0],而多标签分类中,每个样本的标签是 multi hot 表示形式 [0, 1, 0, 1, 0]。注意在使用 BCELoss、BCEWithLogitsLoss、MultiLabelSoftMarginLoss 计算多标签损失时,标签需要表示成 multi hot 形式,另外该 multi hot 中的每个元素必须是 torch.float 类型。

未经允许不得转载:一亩三分地 » 多标签损失计算