基于语义嵌入 Zero-Shot Learning 文本分类

Zero-Shot Learning(零样本学习,ZSL) 是机器学习中的一种技术,指的是模型在没有见过某些类别的训练数据的情况下,仍然能够对这些新类别进行正确的预测。这种能力使得模型能够泛化到未见类别,减少对标注数据的依赖。

下面介绍的基于语义嵌入(Semantic Embedding)的零样本学习方法,主要是通过利用预训练模型的语言理解能力,通过计算文本与类别标签的相似度实现文本分类。

https://huggingface.co/jinaai/jina-embeddings-v3
https://huggingface.co/jinaai/xlm-roberta-flash-implementation/tree/main

1. 分类模型

这段代码实现了一个简洁的文本分类模型,核心思想是利用 jina-embeddings-v3 文本向量模型计算时输入文本候选标签之间的相似度来进行预测。这是一种非常常见、且高效的零样本文本分类方法。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
import logging
logging.basicConfig(level=logging.ERROR)
import torch.nn as nn
import torch
import numpy as np
from transformers import AutoModel
class SentimentAnalysis(nn.Module):
def __init__(self):
super(SentimentAnalysis, self).__init__()
self.encoder = AutoModel.from_pretrained('jina-embeddings-v3', trust_remote_code=True)
def encode(self, texts):
embeddings = self.encoder.encode(texts, convert_to_tensor=True, task='classification')
return embeddings
def similarity(self, text_embeddings, label_embeddings):
return torch.matmul(text_embeddings, label_embeddings.T)
def forward(self, texts, labels):
inputs_embeddings = self.encode(texts) # (3, 1024)
labels_embeddings = self.encode(labels) # (2, 1024)
class_sim = self.similarity(inputs_embeddings, labels_embeddings)
class_ids = torch.argmax(class_sim, axis=-1)
labels = np.take(labels, class_ids.cpu().tolist())
return labels
if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
estimator = SentimentAnalysis().to(device)
texts = [('苹果公司今天发布了全新的 iPhone 15 Pro,这款手机配备了更强大的 A17 芯片和全新的摄像头系统,支持 8K 视频拍摄。', ['科技新闻', '体育赛事', '娱乐动态', '商业财经']),
('昨天晚上,我在家里看了一部科幻电影《星际穿越》,电影里宏大的宇宙场景和深刻的情感主题让我深受感动。', ['科幻电影', '历史纪录片']),
('最近,科学家们在火星表面发现了一种新的矿物质,这种矿物质可能为火星上存在过生命的假设提供了新的线索。', ['太空探索', '娱乐新闻', '科技发明']),
('很失望的一次住宿体验。房间有异味,空调不制冷,而且酒店的位置很偏僻,交通不便。服务态度也很差,投诉后也没有得到解决。不会再来了。', ['好评', '差评'])]
for text, label in texts:
pred = estimator([text], label)
print('文本内容:', text)
print('候选标签:', label)
print('预测标签:', pred)
print('-' * 30)
import logging logging.basicConfig(level=logging.ERROR) import torch.nn as nn import torch import numpy as np from transformers import AutoModel class SentimentAnalysis(nn.Module): def __init__(self): super(SentimentAnalysis, self).__init__() self.encoder = AutoModel.from_pretrained('jina-embeddings-v3', trust_remote_code=True) def encode(self, texts): embeddings = self.encoder.encode(texts, convert_to_tensor=True, task='classification') return embeddings def similarity(self, text_embeddings, label_embeddings): return torch.matmul(text_embeddings, label_embeddings.T) def forward(self, texts, labels): inputs_embeddings = self.encode(texts) # (3, 1024) labels_embeddings = self.encode(labels) # (2, 1024) class_sim = self.similarity(inputs_embeddings, labels_embeddings) class_ids = torch.argmax(class_sim, axis=-1) labels = np.take(labels, class_ids.cpu().tolist()) return labels if __name__ == '__main__': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') estimator = SentimentAnalysis().to(device) texts = [('苹果公司今天发布了全新的 iPhone 15 Pro,这款手机配备了更强大的 A17 芯片和全新的摄像头系统,支持 8K 视频拍摄。', ['科技新闻', '体育赛事', '娱乐动态', '商业财经']), ('昨天晚上,我在家里看了一部科幻电影《星际穿越》,电影里宏大的宇宙场景和深刻的情感主题让我深受感动。', ['科幻电影', '历史纪录片']), ('最近,科学家们在火星表面发现了一种新的矿物质,这种矿物质可能为火星上存在过生命的假设提供了新的线索。', ['太空探索', '娱乐新闻', '科技发明']), ('很失望的一次住宿体验。房间有异味,空调不制冷,而且酒店的位置很偏僻,交通不便。服务态度也很差,投诉后也没有得到解决。不会再来了。', ['好评', '差评'])] for text, label in texts: pred = estimator([text], label) print('文本内容:', text) print('候选标签:', label) print('预测标签:', pred) print('-' * 30)
import logging
logging.basicConfig(level=logging.ERROR)
import torch.nn as nn
import torch
import numpy as np
from transformers import AutoModel


class SentimentAnalysis(nn.Module):

    def __init__(self):
        super(SentimentAnalysis, self).__init__()
        self.encoder = AutoModel.from_pretrained('jina-embeddings-v3', trust_remote_code=True)

    def encode(self, texts):
        embeddings = self.encoder.encode(texts, convert_to_tensor=True, task='classification')
        return embeddings

    def similarity(self, text_embeddings, label_embeddings):
        return torch.matmul(text_embeddings, label_embeddings.T)

    def forward(self, texts, labels):
        inputs_embeddings = self.encode(texts)   # (3, 1024)
        labels_embeddings = self.encode(labels)  # (2, 1024)
        class_sim = self.similarity(inputs_embeddings, labels_embeddings)
        class_ids = torch.argmax(class_sim, axis=-1)
        labels = np.take(labels, class_ids.cpu().tolist())
        return labels


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    estimator = SentimentAnalysis().to(device)

    texts = [('苹果公司今天发布了全新的 iPhone 15 Pro,这款手机配备了更强大的 A17 芯片和全新的摄像头系统,支持 8K 视频拍摄。', ['科技新闻', '体育赛事', '娱乐动态', '商业财经']),
             ('昨天晚上,我在家里看了一部科幻电影《星际穿越》,电影里宏大的宇宙场景和深刻的情感主题让我深受感动。', ['科幻电影', '历史纪录片']),
             ('最近,科学家们在火星表面发现了一种新的矿物质,这种矿物质可能为火星上存在过生命的假设提供了新的线索。', ['太空探索', '娱乐新闻', '科技发明']),
             ('很失望的一次住宿体验。房间有异味,空调不制冷,而且酒店的位置很偏僻,交通不便。服务态度也很差,投诉后也没有得到解决。不会再来了。', ['好评', '差评'])]

    for text, label in texts:
        pred = estimator([text], label)
        print('文本内容:', text)
        print('候选标签:', label)
        print('预测标签:', pred)
        print('-' * 30)
Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
文本内容: 苹果公司今天发布了全新的 iPhone 15 Pro,这款手机配备了更强大的 A17 芯片和全新的摄像头系统,支持 8K 视频拍摄。
候选标签: ['科技新闻', '体育赛事', '娱乐动态', '商业财经']
预测标签: ['科技新闻']
------------------------------
文本内容: 昨天晚上,我在家里看了一部科幻电影《星际穿越》,电影里宏大的宇宙场景和深刻的情感主题让我深受感动。
候选标签: ['科幻电影', '历史纪录片']
预测标签: ['科幻电影']
------------------------------
文本内容: 最近,科学家们在火星表面发现了一种新的矿物质,这种矿物质可能为火星上存在过生命的假设提供了新的线索。
候选标签: ['太空探索', '娱乐新闻', '科技发明']
预测标签: ['太空探索']
------------------------------
文本内容: 很失望的一次住宿体验。房间有异味,空调不制冷,而且酒店的位置很偏僻,交通不便。服务态度也很差,投诉后也没有得到解决。不会再来了。
候选标签: ['好评', '差评']
预测标签: ['差评']
------------------------------
文本内容: 苹果公司今天发布了全新的 iPhone 15 Pro,这款手机配备了更强大的 A17 芯片和全新的摄像头系统,支持 8K 视频拍摄。 候选标签: ['科技新闻', '体育赛事', '娱乐动态', '商业财经'] 预测标签: ['科技新闻'] ------------------------------ 文本内容: 昨天晚上,我在家里看了一部科幻电影《星际穿越》,电影里宏大的宇宙场景和深刻的情感主题让我深受感动。 候选标签: ['科幻电影', '历史纪录片'] 预测标签: ['科幻电影'] ------------------------------ 文本内容: 最近,科学家们在火星表面发现了一种新的矿物质,这种矿物质可能为火星上存在过生命的假设提供了新的线索。 候选标签: ['太空探索', '娱乐新闻', '科技发明'] 预测标签: ['太空探索'] ------------------------------ 文本内容: 很失望的一次住宿体验。房间有异味,空调不制冷,而且酒店的位置很偏僻,交通不便。服务态度也很差,投诉后也没有得到解决。不会再来了。 候选标签: ['好评', '差评'] 预测标签: ['差评'] ------------------------------
文本内容: 苹果公司今天发布了全新的 iPhone 15 Pro,这款手机配备了更强大的 A17 芯片和全新的摄像头系统,支持 8K 视频拍摄。
候选标签: ['科技新闻', '体育赛事', '娱乐动态', '商业财经']
预测标签: ['科技新闻']
------------------------------
文本内容: 昨天晚上,我在家里看了一部科幻电影《星际穿越》,电影里宏大的宇宙场景和深刻的情感主题让我深受感动。
候选标签: ['科幻电影', '历史纪录片']
预测标签: ['科幻电影']
------------------------------
文本内容: 最近,科学家们在火星表面发现了一种新的矿物质,这种矿物质可能为火星上存在过生命的假设提供了新的线索。
候选标签: ['太空探索', '娱乐新闻', '科技发明']
预测标签: ['太空探索']
------------------------------
文本内容: 很失望的一次住宿体验。房间有异味,空调不制冷,而且酒店的位置很偏僻,交通不便。服务态度也很差,投诉后也没有得到解决。不会再来了。
候选标签: ['好评', '差评']
预测标签: ['差评']
------------------------------

2. 评估数据

前面我们构造了几个数据来看看模型的分类能力,下面代码预处理了 7000+ 条酒店评论数据,用于后续对模型进行评估。具体步骤如下:

  1. 读取数据:从 'comments.csv' 文件中读取数据,并将其加载为一个 Pandas DataFrame。
  2. 数据清洗:通过 dropna() 函数删除任何包含缺失值(NaN)的行,确保数据的完整性。
  3. 类别分布统计:使用 Counter 来统计评论的类别分布,即查看“好评”(label=1)和“差评”(label=0)各自的数量。
  4. 数据分离与合并
    • 从数据中分离出好评和差评两部分。
    • 将这两部分重新合并成一个样本集。
    • 将“好评”和“差评”标签替换为更易懂的中文标签:“好评”和“差评”。
  5. 保存数据:将处理后的数据以 NumPy 数组的形式保存到 'samples.pkl' 文件中,以便后续使用。
  6. 输出信息:打印处理后的数据集的维度(即样本数量和特征数量),以及类别分布。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
def demo():
data = pd.read_csv('comments.csv')
data = data.dropna()
print('类别分布:', Counter(data['label']))
negative = data[data['label'] == 0]
positive = data[data['label'] == 1]
samples = pd.concat([negative, positive])
samples['label'] = np.where(samples['label'] == 1, '好评', '差评')
pickle.dump(samples.to_numpy(), open('samples.pkl', 'wb'))
print('评估数据:', samples.shape)
if __name__ == '__main__':
demo()
import pandas as pd import pickle from sklearn.model_selection import train_test_split import numpy as np import matplotlib.pyplot as plt from collections import Counter def demo(): data = pd.read_csv('comments.csv') data = data.dropna() print('类别分布:', Counter(data['label'])) negative = data[data['label'] == 0] positive = data[data['label'] == 1] samples = pd.concat([negative, positive]) samples['label'] = np.where(samples['label'] == 1, '好评', '差评') pickle.dump(samples.to_numpy(), open('samples.pkl', 'wb')) print('评估数据:', samples.shape) if __name__ == '__main__': demo()
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter


def demo():

    data = pd.read_csv('comments.csv')
    data = data.dropna()
    print('类别分布:', Counter(data['label']))

    negative = data[data['label'] == 0]
    positive = data[data['label'] == 1]

    samples = pd.concat([negative, positive])
    samples['label'] = np.where(samples['label'] == 1, '好评', '差评')

    pickle.dump(samples.to_numpy(), open('samples.pkl', 'wb'))
    print('评估数据:', samples.shape)

if __name__ == '__main__':
    demo()
Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
类别分布: Counter({1: 5322, 0: 2443})
评估数据: (7765, 2)
类别分布: Counter({1: 5322, 0: 2443}) 评估数据: (7765, 2)
类别分布: Counter({1: 5322, 0: 2443})
评估数据: (7765, 2)

3. 模型评估

下面这段代码的主要功能是用预处理好的酒店评论数据对情感分析模型进行评估,计算模型的准确度。具体步骤如下:

  1. 设备选择:首先判断是否可以使用 GPU 设备(通过 torch.cuda.is_available()),如果可以则使用 GPU,否则使用 CPU。
  2. 加载模型:初始化一个情感分析模型 SentimentAnalysis 并将其加载到选定的设备上。
  3. 加载数据:从之前保存的 'samples.pkl' 文件中加载数据(酒店评论数据),该数据已经经过预处理。
  4. 定义数据加载器(DataLoader)
    • collate_fn 函数用于将一个批次的输入数据和标签打包成适合模型处理的格式。
    • 使用 DataLoader 来加载数据,批次大小为 32,数据顺序随机打乱。
  5. 评估过程
    • 使用 tqdm 创建进度条,显示评估过程中的准确度。
    • 遍历数据加载器,逐批次获取评论数据和标签,使用模型进行预测,并将预测结果与实际标签进行比较。
    • 计算当前的准确度,并实时更新进度条上的准确度值。
  6. 输出结果:在评估结束时,关闭进度条。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
import logging
logging.basicConfig(level=logging.ERROR)
import pickle
import torch
from torch.utils.data import DataLoader
import tqdm
from sklearn.metrics import accuracy_score
from estimator import SentimentAnalysis
def demo():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
estimator = SentimentAnalysis().to(device)
samples = pickle.load(open('samples.pkl', 'rb'))
def collate_fn(batch_data):
inputs, labels = [], []
for label, input in batch_data:
inputs.append(input)
labels.append(label)
return inputs, labels
dataloader = DataLoader(samples, batch_size=32, shuffle=True, collate_fn=collate_fn)
progress = tqdm.tqdm(range(len(dataloader)), desc='Evaluate Acc: %.3f' % 0)
y_true, y_pred = [], []
for inputs, true_labels in dataloader:
pred_labels = estimator(inputs, ['好评', '差评'])
y_true.extend(true_labels)
y_pred.extend(pred_labels)
accuracy = accuracy_score(y_true, y_pred)
progress.set_description('Evaluate Acc: %.3f' % accuracy)
progress.update()
progress.close()
# watch -n 1 nvidia-smi
if __name__ == '__main__':
demo()
import logging logging.basicConfig(level=logging.ERROR) import pickle import torch from torch.utils.data import DataLoader import tqdm from sklearn.metrics import accuracy_score from estimator import SentimentAnalysis def demo(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') estimator = SentimentAnalysis().to(device) samples = pickle.load(open('samples.pkl', 'rb')) def collate_fn(batch_data): inputs, labels = [], [] for label, input in batch_data: inputs.append(input) labels.append(label) return inputs, labels dataloader = DataLoader(samples, batch_size=32, shuffle=True, collate_fn=collate_fn) progress = tqdm.tqdm(range(len(dataloader)), desc='Evaluate Acc: %.3f' % 0) y_true, y_pred = [], [] for inputs, true_labels in dataloader: pred_labels = estimator(inputs, ['好评', '差评']) y_true.extend(true_labels) y_pred.extend(pred_labels) accuracy = accuracy_score(y_true, y_pred) progress.set_description('Evaluate Acc: %.3f' % accuracy) progress.update() progress.close() # watch -n 1 nvidia-smi if __name__ == '__main__': demo()
import logging
logging.basicConfig(level=logging.ERROR)
import pickle
import torch
from torch.utils.data import DataLoader
import tqdm
from sklearn.metrics import accuracy_score
from estimator import SentimentAnalysis


def demo():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    estimator = SentimentAnalysis().to(device)
    samples = pickle.load(open('samples.pkl', 'rb'))

    def collate_fn(batch_data):
        inputs, labels = [], []
        for label, input in batch_data:
            inputs.append(input)
            labels.append(label)
        return inputs, labels

    dataloader = DataLoader(samples, batch_size=32, shuffle=True, collate_fn=collate_fn)
    progress = tqdm.tqdm(range(len(dataloader)), desc='Evaluate Acc: %.3f' % 0)

    y_true, y_pred = [], []
    for inputs, true_labels in dataloader:
        pred_labels = estimator(inputs, ['好评', '差评'])
        y_true.extend(true_labels)
        y_pred.extend(pred_labels)
        accuracy = accuracy_score(y_true, y_pred)
        progress.set_description('Evaluate Acc: %.3f' % accuracy)
        progress.update()
    progress.close()


# watch -n 1 nvidia-smi
if __name__ == '__main__':
    demo()
Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
Evaluate Acc: 0.905: 100%|████████████████████| 486/486 [02:57<00:00, 2.74it/s]
Evaluate Acc: 0.905: 100%|████████████████████| 486/486 [02:57<00:00, 2.74it/s]
Evaluate Acc: 0.905: 100%|████████████████████| 486/486 [02:57<00:00,  2.74it/s]
未经允许不得转载:一亩三分地 » 基于语义嵌入 Zero-Shot Learning 文本分类
评论 (0)

4 + 7 =