基于类别均值的分类方法(Nearest Mean Classification,NMC)是一种简单的监督学习分类算法。它的核心思想是通过计算每个类别的样本均值向量,然后将新样本归类到与其均值向量最相似的类别。这种方法在某些特定场景下表现出较好的分类效果,尤其是在类别分布较为集中且特征维度较低的情况下。

这种分类方法虽然简单,但也具有一定的局限性。例如:它对类别分布的假设较强,假设每个类别的样本分布是围绕均值向量的球形分布。如果类别内部的样本分布较为复杂,那么该方法可能无法准确捕捉类别特征,从而导致分类效果不佳。
1. 准备数据
import pickle
from collections import Counter
import glob
import os


def demo():
    """Convert the raw ThuCnews text files into pickled (label, text) lists.

    Each ``ThuCnews/*.txt`` file holds one sample per line in the form
    ``label<TAB>text``.  For every input file this writes ``<basename>.pkl``
    (a list of ``(label, text)`` tuples) into the current directory and
    prints the per-label sample counts.
    """
    fnames = glob.glob('ThuCnews/*.txt')
    for fname in fnames:
        name, ext = os.path.splitext(os.path.basename(fname))
        reports, targets = [], []
        # The corpus is Chinese text: open with an explicit encoding instead
        # of relying on the platform default.
        with open(fname, encoding='utf-8') as file:
            for line in file:
                target, report = line.strip().split('\t')
                reports.append((target, report))
                targets.append(target)
        print(name, Counter(targets))
        # Use a context manager so the output handle is closed deterministically
        # (the original `pickle.dump(reports, open(...))` leaked the handle).
        with open(f'{name}.pkl', 'wb') as out:
            pickle.dump(reports, out)


if __name__ == '__main__':
    demo()
test Counter({'体育': 1000, '娱乐': 1000, '家居': 1000, '房产': 1000, '教育': 1000, '时尚': 1000, '时政': 1000, '游戏': 1000, '科技': 1000, '财经': 1000}) valid Counter({'体育': 500, '娱乐': 500, '家居': 500, '房产': 500, '教育': 500, '时尚': 500, '时政': 500, '游戏': 500, '科技': 500, '财经': 500}) train Counter({'体育': 5000, '娱乐': 5000, '家居': 5000, '房产': 5000, '教育': 5000, '时尚': 5000, '时政': 5000, '游戏': 5000, '科技': 5000, '财经': 5000})
2. 类别向量
由于训练集数量较多,计算较为耗时。因此代码在计算类别均值向量时,使用的是验证集的 5000 条样本。
import logging
logging.basicConfig(level=logging.ERROR)
import os
import pickle
import numpy as np
import torch.nn as nn
import torch
from transformers import BertModel
from transformers import BertTokenizer
from transformers import AutoModel
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import tqdm
from collections import defaultdict
from sklearn.metrics import accuracy_score


class SentimentAnalysis(nn.Module):
    """Nearest-mean (nearest-centroid) text classifier.

    Texts are embedded with jina-embeddings-v3 (truncated to 1024 dims);
    `fit` averages the embeddings of each class into a centroid, and
    `predict` assigns a text to the class whose centroid has the largest
    dot-product similarity with its embedding.
    """

    def __init__(self):
        super(SentimentAnalysis, self).__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('device:', device)
        self.encoder = AutoModel.from_pretrained('jina-embeddings-v3', trust_remote_code=True).to(device)

    def encode(self, texts):
        """Embed a list of texts into a (len(texts), 1024) tensor.

        NOTE(review): relies on the jina-embeddings-v3 custom `encode`
        method (trust_remote_code); assumed to return tensors on the
        encoder's device — confirm against the model card.
        """
        embeddings = self.encoder.encode(texts, convert_to_tensor=True, truncate_dim=1024, task='classification')
        return embeddings

    def fit(self, train_data):
        """Compute one mean embedding per class from (label, text) pairs.

        Sets `self.targets` (list of class names) and `self.vectors`
        (tensor of shape (num_classes, 1024), kept on CPU).
        """
        def collate_fn(batch_data):
            # Unzip a batch of (label, text) tuples into parallel lists.
            batch_targets, batch_reports = [], []
            for target, report in batch_data:
                batch_targets.append(target)
                batch_reports.append(report)
            return batch_targets, batch_reports

        dataloader = DataLoader(train_data, batch_size=2, shuffle=True, collate_fn=collate_fn)
        progress = tqdm.tqdm(range(len(dataloader)), desc='Train')
        # Running per-class sums and counts; the mean is taken at the end.
        class_numbers, class_vectors = {}, {}
        for batch_targets, batch_reports in dataloader:
            with torch.no_grad():
                batch_vectors = self.encode(batch_reports)
            for target, vector in zip(batch_targets, batch_vectors):
                if target not in class_vectors:
                    class_vectors[target] = torch.zeros(1024)
                if target not in class_numbers:
                    class_numbers[target] = 0
                # Accumulate on CPU so GPU memory is not held for sums.
                class_vectors[target] += vector.cpu()
                class_numbers[target] += 1
            progress.update()
        progress.close()
        self.targets, self.vectors = [], []
        for class_name, class_number in class_numbers.items():
            self.targets.append(class_name)
            self.vectors.append(class_vectors[class_name] / class_number)
        self.vectors = torch.stack(self.vectors)
        print(self.vectors.shape)

    def compute_similarity(self, vectors):
        """Dot-product similarity between query vectors and class centroids.

        vectors: (batch, 1024); self.vectors: (num_classes, 1024).
        Returns (batch, num_classes).
        """
        # Move the centroids to the queries' device: after `fit` they live
        # on CPU while encoded queries may be on GPU (the original matmul
        # raised a device-mismatch error in that situation).
        return torch.matmul(vectors, self.vectors.to(vectors.device).T)

    def predict(self, texts):
        """Return the predicted class name for each input text."""
        text_vectors = self.encode(texts)
        y_preds = self.compute_similarity(text_vectors)
        y_preds = torch.argmax(y_preds, dim=-1)
        y_preds = [self.targets[idx] for idx in y_preds.cpu().tolist()]
        return y_preds

    def save(self):
        """Persist the class names and centroids under `model/`."""
        # Create the output directory if needed and close the handles
        # deterministically (the original leaked them via pickle.dump(open(...))).
        os.makedirs('model', exist_ok=True)
        with open('model/vectors.pkl', 'wb') as f:
            pickle.dump(self.vectors, f)
        with open('model/targets.pkl', 'wb') as f:
            pickle.dump(self.targets, f)

    @classmethod
    def load(cls):
        """Rebuild an estimator from the artifacts written by `save`.

        The original re-downloaded/re-loaded the encoder a second time after
        `cls()` had already done so; that redundant load is removed.
        """
        estimator = cls()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        with open('model/vectors.pkl', 'rb') as f:
            estimator.vectors = pickle.load(f).to(device)
        with open('model/targets.pkl', 'rb') as f:
            estimator.targets = pickle.load(f)
        return estimator


def train():
    """Fit the nearest-mean classifier on the validation split and save it.

    The validation set (5000 samples) is used instead of the full training
    set to keep the centroid computation fast.
    """
    estimator = SentimentAnalysis()
    # train_data = pickle.load(open('train.pkl', 'rb'))
    with open('valid.pkl', 'rb') as f:
        valid_data = pickle.load(f)
    estimator.fit(valid_data)
    estimator.save()


# watch -n 1 nvidia-smi
# top
if __name__ == '__main__':
    train()
3. 查看效果
from model import SentimentAnalysis
import pickle
import torch
from torch.utils.data import DataLoader
import tqdm
from sklearn.metrics import accuracy_score


def evaluate():
    """Score the saved nearest-mean classifier on the held-out test set.

    Streams the test pickle in batches, keeps a running accuracy in the
    tqdm description, and returns the final accuracy (0.0 on an empty set).
    """
    estimator = SentimentAnalysis.load()
    with open('test.pkl', 'rb') as f:
        test_data = pickle.load(f)

    def collate_fn(batch_data):
        # Unzip a batch of (label, text) tuples into parallel lists.
        batch_targets, batch_reports = [], []
        for target, report in batch_data:
            batch_targets.append(target)
            batch_reports.append(report)
        return batch_targets, batch_reports

    dataloader = DataLoader(test_data, batch_size=2, collate_fn=collate_fn)
    progress = tqdm.tqdm(range(len(dataloader)), desc='Evaluate Acc: %.3f' % 0)
    # Running counters instead of re-running accuracy_score over the whole
    # accumulated prediction list every batch (the original was quadratic
    # in the number of evaluated samples).
    correct, total = 0, 0
    for batch_labels, batch_texts in dataloader:
        batch_preds = estimator.predict(batch_texts)
        correct += sum(t == p for t, p in zip(batch_labels, batch_preds))
        total += len(batch_labels)
        # 计算准确率
        accuracy = correct / total
        progress.set_description('Evaluate Acc: %.3f' % accuracy)
        progress.update()
    progress.close()
    return correct / total if total else 0.0


if __name__ == '__main__':
    evaluate()
下面是在测试集 10000 条样本上的准确率:
Evaluate Acc: 0.898: 100%|██████████████████| 5000/5000 [14:29<00:00, 5.75it/s]